Diffstat (limited to 'core/src/main/scala/org/apache')
266 files changed, 34927 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala new file mode 100644 index 0000000000..f87460039b --- /dev/null +++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.mapred + +trait SparkHadoopMapRedUtil { + def newJobContext(conf: JobConf, jobId: JobID): JobContext = { + val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext"); + val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + ctor.newInstance(conf, jobId).asInstanceOf[JobContext] + } + + def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { + val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") + val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) + ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] + } + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = { + new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) + } + + private def firstAvailableClass(first: String, second: String): Class[_] = { + try { + Class.forName(first) + } catch { + case e: ClassNotFoundException => + Class.forName(second) + } + } +} diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala new file mode 100644 index 0000000000..93180307fa --- /dev/null +++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.mapreduce + +import org.apache.hadoop.conf.Configuration +import java.lang.{Integer => JInteger, Boolean => JBoolean} + +trait SparkHadoopMapReduceUtil { + def newJobContext(conf: Configuration, jobId: JobID): JobContext = { + val klass = firstAvailableClass( + "org.apache.hadoop.mapreduce.task.JobContextImpl", // hadoop2, hadoop2-yarn + "org.apache.hadoop.mapreduce.JobContext") // hadoop1 + val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID]) + ctor.newInstance(conf, jobId).asInstanceOf[JobContext] + } + + def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = { + val klass = firstAvailableClass( + "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl", // hadoop2, hadoop2-yarn + "org.apache.hadoop.mapreduce.TaskAttemptContext") // hadoop1 + val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID]) + ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] + } + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = { + val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID"); + try { + // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN) + val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], + classOf[Int], classOf[Int]) + ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new + JInteger(attemptId)).asInstanceOf[TaskAttemptID] + } catch { + case exc: NoSuchMethodException => { + // failed, look for the new ctor that takes a TaskType (not available in 1.x) + val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]] + val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE") + val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, + classOf[Int], classOf[Int]) + ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new + JInteger(attemptId)).asInstanceOf[TaskAttemptID] + } + } + } + + private def firstAvailableClass(first: String, second: String): Class[_] = { + try { + Class.forName(first) + } catch { + case e: ClassNotFoundException => + Class.forName(second) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala new file mode 100644 index 0000000000..6e922a612a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io._ + +import scala.collection.mutable.Map +import scala.collection.generic.Growable +import org.apache.spark.serializer.JavaSerializer + +/** + * A datatype that can be accumulated, i.e. has an commutative and associative "add" operation, + * but where the result type, `R`, may be different from the element type being added, `T`. + * + * You must define how to add data, and how to merge two of these together. For some datatypes, + * such as a counter, these might be the same operation. In that case, you can use the simpler + * [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are + * accumulating a set. You will add items to the set, and you will union two sets together. + * + * @param initialValue initial value of accumulator + * @param param helper object defining how to add elements of type `R` and `T` + * @tparam R the full accumulated data (result type) + * @tparam T partial data that can be added in + */ +class Accumulable[R, T] ( + @transient initialValue: R, + param: AccumulableParam[R, T]) + extends Serializable { + + val id = Accumulators.newId + @transient private var value_ = initialValue // Current value on master + val zero = param.zero(initialValue) // Zero value to be passed to workers + var deserialized = false + + Accumulators.register(this, true) + + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def += (term: T) { value_ = param.addAccumulator(value_, term) } + + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `+=`. + * @param term the other `R` that will get merged with this + */ + def ++= (term: R) { value_ = param.addInPlace(value_, term)} + + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + + /** + * Access the accumulator's current value; only allowed on master. + */ + def value: R = { + if (!deserialized) { + value_ + } else { + throw new UnsupportedOperationException("Can't read accumulator value in task") + } + } + + /** + * Get the current value of this accumulator from within a task. + * + * This is NOT the global value of the accumulator. To get the global value after a + * completed operation on the dataset, call `value`. + * + * The typical use of this method is to directly mutate the local value, eg., to add + * an element to a Set. + */ + def localValue = value_ + + /** + * Set the accumulator's value; only allowed on master. + */ + def value_= (newValue: R) { + if (!deserialized) value_ = newValue + else throw new UnsupportedOperationException("Can't assign accumulator value in task") + } + + /** + * Set the accumulator's value; only allowed on master + */ + def setValue(newValue: R) { + this.value = newValue + } + + // Called by Java when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + value_ = zero + deserialized = true + Accumulators.register(this, false) + } + + override def toString = value_.toString +} + +/** + * Helper object defining how to accumulate values of a particular type. 
An implicit + * AccumulableParam needs to be available when you create Accumulables of a specific type. + * + * @tparam R the full accumulated data (result type) + * @tparam T partial data that can be added in + */ +trait AccumulableParam[R, T] extends Serializable { + /** + * Add additional data to the accumulator value. Is allowed to modify and return `r` + * for efficiency (to avoid allocating objects). + * + * @param r the current value of the accumulator + * @param t the data to be added to the accumulator + * @return the new value of the accumulator + */ + def addAccumulator(r: R, t: T): R + + /** + * Merge two accumulated values together. Is allowed to modify and return the first value + * for efficiency (to avoid allocating objects). + * + * @param r1 one set of accumulated data + * @param r2 another set of accumulated data + * @return both data sets merged together + */ + def addInPlace(r1: R, r2: R): R + + /** + * Return the "zero" (identity) value for an accumulator type, given its initial value. For + * example, if R was a vector of N dimensions, this would return a vector of N zeroes. + */ + def zero(initialValue: R): R +} + +private[spark] +class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] + extends AccumulableParam[R,T] { + + def addAccumulator(growable: R, elem: T): R = { + growable += elem + growable + } + + def addInPlace(t1: R, t2: R): R = { + t1 ++= t2 + t1 + } + + def zero(initialValue: R): R = { + // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. + // Instead we'll serialize it to a buffer and load it back. + val ser = new JavaSerializer().newInstance() + val copy = ser.deserialize[R](ser.serialize(initialValue)) + copy.clear() // In case it contained stuff + copy + } +} + +/** + * A simpler value of [[org.apache.spark.Accumulable]] where the result type being accumulated is the same + * as the types of elements being merged. + * + * @param initialValue initial value of accumulator + * @param param helper object defining how to add elements of type `T` + * @tparam T result type + */ +class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) + extends Accumulable[T,T](initialValue, param) + +/** + * A simpler version of [[org.apache.spark.AccumulableParam]] where the only datatype you can add in is the same type + * as the accumulated value. An implicit AccumulatorParam object needs to be available when you create + * Accumulators of a specific type. + * + * @tparam T type of value to accumulate + */ +trait AccumulatorParam[T] extends AccumulableParam[T, T] { + def addAccumulator(t1: T, t2: T): T = { + addInPlace(t1, t2) + } +} + +// TODO: The multi-thread support in accumulators is kind of lame; check +// if there's a more intuitive way of doing it right +private object Accumulators { + // TODO: Use soft references? 
=> need to make readObject work properly then + val originals = Map[Long, Accumulable[_, _]]() + val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() + var lastId: Long = 0 + + def newId: Long = synchronized { + lastId += 1 + return lastId + } + + def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { + if (original) { + originals(a.id) = a + } else { + val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map()) + accums(a.id) = a + } + } + + // Clear the local (non-original) accumulators for the current thread + def clear() { + synchronized { + localAccums.remove(Thread.currentThread) + } + } + + // Get the values of the local accumulators for the current thread (by ID) + def values: Map[Long, Any] = synchronized { + val ret = Map[Long, Any]() + for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) { + ret(id) = accum.localValue + } + return ret + } + + // Add values to the original accumulators with some given IDs + def add(values: Map[Long, Any]): Unit = synchronized { + for ((id, value) <- values) { + if (originals.contains(id)) { + originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala new file mode 100644 index 0000000000..3ef402926e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{HashMap => JHashMap} + +import scala.collection.JavaConversions._ + +/** A set of functions used to aggregate data. + * + * @param createCombiner function to create the initial value of the aggregation. + * @param mergeValue function to merge a new value into the aggregation result. + * @param mergeCombiners function to merge outputs from multiple mergeValue function. 
+ */ +case class Aggregator[K, V, C] ( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) { + + def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { + val combiners = new JHashMap[K, C] + for (kv <- iter) { + val oldC = combiners.get(kv._1) + if (oldC == null) { + combiners.put(kv._1, createCombiner(kv._2)) + } else { + combiners.put(kv._1, mergeValue(oldC, kv._2)) + } + } + combiners.iterator + } + + def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { + val combiners = new JHashMap[K, C] + iter.foreach { case(k, c) => + val oldC = combiners.get(k) + if (oldC == null) { + combiners.put(k, c) + } else { + combiners.put(k, mergeCombiners(oldC, c)) + } + } + combiners.iterator + } +} + diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala new file mode 100644 index 0000000000..908ff56a6b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CompletionIterator + + +private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { + + override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) + : Iterator[T] = + { + + logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) + val blockManager = SparkEnv.get.blockManager + + val startTime = System.currentTimeMillis + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( + shuffleId, reduceId, System.currentTimeMillis - startTime)) + + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] + for (((address, size), index) <- statuses.zipWithIndex) { + splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) + } + + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { + case (address, splits) => + (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) + } + + def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = { + val blockId = blockPair._1 + val blockOption = blockPair._2 + blockOption match { + case Some(block) => { + block.asInstanceOf[Iterator[T]] + } + case None => { + val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r + blockId match { + case regex(shufId, mapId, _) => + val address = statuses(mapId.toInt)._1 + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") + } + } + } + } + + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val itr = blockFetcherItr.flatMap(unpackBlock) + + CompletionIterator[T, Iterator[T]](itr, { + val shuffleMetrics = new ShuffleReadMetrics + shuffleMetrics.shuffleFinishTime = System.currentTimeMillis + shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime + shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime + shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead + shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks + shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks + shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks + metrics.shuffleReadMetrics = Some(shuffleMetrics) + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala new file mode 100644 index 0000000000..e299a106ee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.{ArrayBuffer, HashSet} +import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.rdd.RDD + + +/** Spark class responsible for passing RDDs split contents to the BlockManager and making + sure a node doesn't load two copies of an RDD at once. + */ +private[spark] class CacheManager(blockManager: BlockManager) extends Logging { + private val loading = new HashSet[String] + + /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ + def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { + val key = "rdd_%d_%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Partition is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ : Throwable =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } + } + try { + // If we got here, we have to load the split + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.computeOrReadCheckpoint(split, context) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) + return elements.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala new file mode 100644 index 0000000000..cc30105940 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.rdd.RDD + +/** + * Base class for dependencies. + */ +abstract class Dependency[T](val rdd: RDD[T]) extends Serializable + + +/** + * Base class for dependencies where each partition of the parent RDD is used by at most one + * partition of the child RDD. Narrow dependencies allow for pipelined execution. + */ +abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { + /** + * Get the parent partitions for a child partition. + * @param partitionId a partition of the child RDD + * @return the partitions of the parent RDD that the child partition depends upon + */ + def getParents(partitionId: Int): Seq[Int] +} + + +/** + * Represents a dependency on the output of a shuffle stage. + * @param rdd the parent RDD + * @param partitioner partitioner used to partition the shuffle output + * @param serializerClass class name of the serializer to use + */ +class ShuffleDependency[K, V]( + @transient rdd: RDD[_ <: Product2[K, V]], + val partitioner: Partitioner, + val serializerClass: String = null) + extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + + val shuffleId: Int = rdd.context.newShuffleId() +} + + +/** + * Represents a one-to-one dependency between partitions of the parent and child RDDs. + */ +class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { + override def getParents(partitionId: Int) = List(partitionId) +} + + +/** + * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. + * @param rdd the parent RDD + * @param inStart the start of the range in the parent RDD + * @param outStart the start of the range in the child RDD + * @param length the length of the range + */ +class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) + extends NarrowDependency[T](rdd) { + + override def getParents(partitionId: Int) = { + if (partitionId >= outStart && partitionId < outStart + length) { + List(partitionId - outStart + inStart) + } else { + Nil + } + } +} diff --git a/core/src/main/scala/org/apache/spark/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/FetchFailedException.scala new file mode 100644 index 0000000000..d242047502 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/FetchFailedException.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.storage.BlockManagerId + +private[spark] class FetchFailedException( + taskEndReason: TaskEndReason, + message: String, + cause: Throwable) + extends Exception { + + def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), + cause) + + def this (shuffleId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(null, shuffleId, -1, reduceId), + "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) + + override def getMessage(): String = message + + + override def getCause(): Throwable = cause + + def toTaskEndReason: TaskEndReason = taskEndReason + +} diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala new file mode 100644 index 0000000000..ad1ee20045 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.{File} +import com.google.common.io.Files +import org.apache.spark.util.Utils + +private[spark] class HttpFileServer extends Logging { + + var baseDir : File = null + var fileDir : File = null + var jarDir : File = null + var httpServer : HttpServer = null + var serverUri : String = null + + def initialize() { + baseDir = Utils.createTempDir() + fileDir = new File(baseDir, "files") + jarDir = new File(baseDir, "jars") + fileDir.mkdir() + jarDir.mkdir() + logInfo("HTTP File server directory is " + baseDir) + httpServer = new HttpServer(baseDir) + httpServer.start() + serverUri = httpServer.uri + } + + def stop() { + httpServer.stop() + } + + def addFile(file: File) : String = { + addFileToDir(file, fileDir) + return serverUri + "/files/" + file.getName + } + + def addJar(file: File) : String = { + addFileToDir(file, jarDir) + return serverUri + "/jars/" + file.getName + } + + def addFileToDir(file: File, dir: File) : String = { + Files.copy(file, new File(dir, file.getName)) + return dir + "/" + file.getName + } + +} diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala new file mode 100644 index 0000000000..cdfc9dd54e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File +import java.net.InetAddress + +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.bio.SocketConnector +import org.eclipse.jetty.server.handler.DefaultHandler +import org.eclipse.jetty.server.handler.HandlerList +import org.eclipse.jetty.server.handler.ResourceHandler +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.apache.spark.util.Utils + +/** + * Exception type thrown by HttpServer when it is in the wrong state for an operation. + */ +private[spark] class ServerStateException(message: String) extends Exception(message) + +/** + * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext + * as well as classes created by the interpreter when the user types in code. This is just a wrapper + * around a Jetty server. + */ +private[spark] class HttpServer(resourceBase: File) extends Logging { + private var server: Server = null + private var port: Int = -1 + + def start() { + if (server != null) { + throw new ServerStateException("Server is already started") + } else { + server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60*1000) + connector.setSoLingerTime(-1) + connector.setPort(0) + server.addConnector(connector) + + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val resHandler = new ResourceHandler + resHandler.setResourceBase(resourceBase.getAbsolutePath) + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + server.setHandler(handlerList) + server.start() + port = server.getConnectors()(0).getLocalPort() + } + } + + def stop() { + if (server == null) { + throw new ServerStateException("Server is already stopped") + } else { + server.stop() + port = -1 + server = null + } + } + + /** + * Get the URI of this HTTP server (http://host:port) + */ + def uri: String = { + if (server == null) { + throw new ServerStateException("Server is not started") + } else { + return "http://" + Utils.localIpAddress + ":" + port + } + } +} diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala new file mode 100644 index 0000000000..6a973ea495 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows + * logging messages at different levels using methods that only evaluate parameters lazily if the + * log level is enabled. + */ +trait Logging { + // Make the log field transient so that objects with Logging can + // be serialized and used on another machine + @transient private var log_ : Logger = null + + // Method to get or create the logger for this object + protected def log: Logger = { + if (log_ == null) { + var className = this.getClass.getName + // Ignore trailing $'s in the class names for Scala objects + if (className.endsWith("$")) { + className = className.substring(0, className.length - 1) + } + log_ = LoggerFactory.getLogger(className) + } + return log_ + } + + // Log methods that take only a String + protected def logInfo(msg: => String) { + if (log.isInfoEnabled) log.info(msg) + } + + protected def logDebug(msg: => String) { + if (log.isDebugEnabled) log.debug(msg) + } + + protected def logTrace(msg: => String) { + if (log.isTraceEnabled) log.trace(msg) + } + + protected def logWarning(msg: => String) { + if (log.isWarnEnabled) log.warn(msg) + } + + protected def logError(msg: => String) { + if (log.isErrorEnabled) log.error(msg) + } + + // Log methods that take Throwables (Exceptions/Errors) too + protected def logInfo(msg: => String, throwable: Throwable) { + if (log.isInfoEnabled) log.info(msg, throwable) + } + + protected def logDebug(msg: => String, throwable: Throwable) { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + + protected def logTrace(msg: => String, throwable: Throwable) { + if (log.isTraceEnabled) log.trace(msg, throwable) + } + + protected def logWarning(msg: => String, throwable: Throwable) { + if (log.isWarnEnabled) log.warn(msg, throwable) + } + + protected def logError(msg: => String, throwable: Throwable) { + if (log.isErrorEnabled) log.error(msg, throwable) + } + + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + + // Method for ensuring that logging is initialized, to avoid having multiple + // threads do it concurrently (as SLF4J initialization is not thread safe). + protected def initLogging() { log } +} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala new file mode 100644 index 0000000000..1afb1870f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io._ +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import akka.actor._ +import scala.concurrent.Await +import akka.pattern.ask +import akka.remote._ + +import scala.concurrent.duration.Duration +import akka.util.Timeout +import scala.concurrent.duration._ + +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap} + + +private[spark] sealed trait MapOutputTrackerMessage +private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) + extends MapOutputTrackerMessage +private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage + +private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { + def receive = { + case GetMapOutputStatuses(shuffleId: Int, requester: String) => + logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) + sender ! tracker.getSerializedLocations(shuffleId) + + case StopMapOutputTracker => + logInfo("MapOutputTrackerActor stopped!") + sender ! true + context.stop(self) + } +} + +private[spark] class MapOutputTracker extends Logging { + + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + // Set to the MapOutputTrackerActor living on the driver + var trackerActor: ActorRef = _ + + private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + + // Incremented every time a fetch fails so that client nodes know to clear + // their cache of map output locations if this happens. + private var epoch: Long = 0 + private val epochLock = new java.lang.Object + + // Cache a serialized version of the output statuses for each shuffle to send them out faster + var cacheEpoch = epoch + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) + + // Send a message to the trackerActor and get its result within a default timeout, or + // throw a SparkException if this fails. + def askTracker(message: Any): Any = { + try { + val future = trackerActor.ask(message)(timeout) + return Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error communicating with MapOutputTracker", e) + } + } + + // Send a one-way message to the trackerActor, to which we expect it to reply with true. 
+ def communicate(message: Any) { + if (askTracker(message) != true) { + throw new SparkException("Error reply received from MapOutputTracker") + } + } + + def registerShuffle(shuffleId: Int, numMaps: Int) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } + + def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { + var array = mapStatuses(shuffleId) + array.synchronized { + array(mapId) = status + } + } + + def registerMapOutputs( + shuffleId: Int, + statuses: Array[MapStatus], + changeEpoch: Boolean = false) { + mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) + if (changeEpoch) { + incrementEpoch() + } + } + + def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + var arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + var array = arrayOpt.get + array.synchronized { + if (array(mapId) != null && array(mapId).location == bmAddress) { + array(mapId) = null + } + } + incrementEpoch() + } else { + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + } + } + + // Remembers which map output locations are currently being fetched on a worker + private val fetching = new HashSet[Int] + + // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle + def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + if (fetching.contains(shuffleId)) { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. 
+ fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the output locs; do so + logInfo("Doing the fetch; tracker actor = " + trackerActor) + val hostPort = Utils.localHostPort() + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = + askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + fetchedStatuses = deserializeStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + if (fetchedStatuses != null) { + fetchedStatuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) + } + } + else{ + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + } else { + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } + } + } + + private def cleanup(cleanupTime: Long) { + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) + } + + def stop() { + communicate(StopMapOutputTracker) + mapStatuses.clear() + metadataCleaner.cancel() + trackerActor = null + } + + // Called on master to increment the epoch number + def incrementEpoch() { + epochLock.synchronized { + epoch += 1 + logDebug("Increasing epoch to " + epoch) + } + } + + // Called on master or workers to get current epoch number + def getEpoch: Long = { + epochLock.synchronized { + return epoch + } + } + + // Called on workers to update the epoch number, potentially clearing old outputs + // because of a fetch failure. (Each worker task calls this with the latest epoch + // number on the master at the time it was created.) + def updateEpoch(newEpoch: Long) { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + mapStatuses.clear() + epoch = newEpoch + } + } + } + + def getSerializedLocations(shuffleId: Int): Array[Byte] = { + var statuses: Array[MapStatus] = null + var epochGotten: Long = -1 + epochLock.synchronized { + if (epoch > cacheEpoch) { + cachedSerializedStatuses.clear() + cacheEpoch = epoch + } + cachedSerializedStatuses.get(shuffleId) match { + case Some(bytes) => + return bytes + case None => + statuses = mapStatuses(shuffleId) + epochGotten = epoch + } + } + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "locs"; let's serialize and return that + val bytes = serializeStatuses(statuses) + logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) + // Add them into the table only if the epoch hasn't changed while we were working + epochLock.synchronized { + if (epoch == epochGotten) { + cachedSerializedStatuses(shuffleId) = bytes + } + } + return bytes + } + + // Serialize an array of map output locations into an efficient byte format so that we can send + // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will + // generally be pretty compressible because many map outputs will be on the same hostname. 
+ private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + val out = new ByteArrayOutputStream + val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + objOut.close() + out.toByteArray + } + + // Opposite of serializeStatuses. + def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { + val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) + objIn.readObject(). + // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present + // comment this out - nulls could be due to missing location ? + asInstanceOf[Array[MapStatus]] // .filter( _ != null ) + } +} + +private[spark] object MapOutputTracker { + private val LOG_BASE = 1.1 + + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + private def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + assert (statuses != null) + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.location, decompressSize(status.compressedSizes(reduceId))) + } + } + } + + /** + * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. + * We do this by encoding the log base 1.1 of the size as an integer, which can support + * sizes up to 35 GB with at most 10% error. + */ + def compressSize(size: Long): Byte = { + if (size == 0) { + 0 + } else if (size <= 1L) { + 1 + } else { + math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte + } + } + + /** + * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. + */ + def decompressSize(compressedSize: Byte): Long = { + if (compressedSize == 0) { + 0 + } else { + math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong + } + } +} diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala new file mode 100644 index 0000000000..87914a061f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Partition.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * A partition of an RDD. 
+ */ +trait Partition extends Serializable { + /** + * Get the split's index within its parent RDD + */ + def index: Int + + // A better default implementation of HashCode + override def hashCode(): Int = index +} diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala new file mode 100644 index 0000000000..62b608c088 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.util.Utils +import org.apache.spark.rdd.RDD + +import scala.reflect.ClassTag + +/** + * An object that defines how the elements in a key-value pair RDD are partitioned by key. + * Maps each key to a partition ID, from 0 to `numPartitions - 1`. + */ +abstract class Partitioner extends Serializable { + def numPartitions: Int + def getPartition(key: Any): Int +} + +object Partitioner { + /** + * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. + * + * If any of the RDDs already has a partitioner, choose that one. + * + * Otherwise, we use a default HashPartitioner. For the number of partitions, if + * spark.default.parallelism is set, then we'll use the value from SparkContext + * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * + * Unless spark.default.parallelism is set, He number of partitions will be the + * same as the number of partitions in the largest upstream RDD, as this should + * be least likely to cause out-of-memory errors. + * + * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. + */ + def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { + val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse + for (r <- bySize if r.partitioner != None) { + return r.partitioner.get + } + if (System.getProperty("spark.default.parallelism") != null) { + return new HashPartitioner(rdd.context.defaultParallelism) + } else { + return new HashPartitioner(bySize.head.partitions.size) + } + } +} + +/** + * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * + * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, + * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will + * produce an unexpected or incorrect result. 
+ */ +class HashPartitioner(partitions: Int) extends Partitioner { + def numPartitions = partitions + + def getPartition(key: Any): Int = key match { + case null => 0 + case _ => Utils.nonNegativeMod(key.hashCode, numPartitions) + } + + override def equals(other: Any): Boolean = other match { + case h: HashPartitioner => + h.numPartitions == numPartitions + case _ => + false + } +} + +/** + * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. + * Determines the ranges by sampling the RDD passed in. + */ +class RangePartitioner[K <% Ordered[K]: ClassTag, V]( + partitions: Int, + @transient rdd: RDD[_ <: Product2[K,V]], + private val ascending: Boolean = true) + extends Partitioner { + + // An array of upper bounds for the first (partitions - 1) partitions + private val rangeBounds: Array[K] = { + if (partitions == 1) { + Array() + } else { + val rddSize = rdd.count() + val maxSampleSize = partitions * 20.0 + val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) + val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _) + if (rddSample.length == 0) { + Array() + } else { + val bounds = new Array[K](partitions - 1) + for (i <- 0 until partitions - 1) { + val index = (rddSample.length - 1) * (i + 1) / partitions + bounds(i) = rddSample(index) + } + bounds + } + } + } + + def numPartitions = partitions + + def getPartition(key: Any): Int = { + // TODO: Use a binary search here if number of partitions is large + val k = key.asInstanceOf[K] + var partition = 0 + while (partition < rangeBounds.length && k > rangeBounds(partition)) { + partition += 1 + } + if (ascending) { + partition + } else { + rangeBounds.length - partition + } + } + + override def equals(other: Any): Boolean = other match { + case r: RangePartitioner[_,_] => + r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending + case _ => + false + } +} diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala new file mode 100644 index 0000000000..fdd4c24e23 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io._ + +import org.apache.hadoop.io.ObjectWritable +import org.apache.hadoop.io.Writable +import org.apache.hadoop.conf.Configuration + +class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { + def value = t + override def toString = t.toString + + private def writeObject(out: ObjectOutputStream) { + out.defaultWriteObject() + new ObjectWritable(t).write(out) + } + + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + val ow = new ObjectWritable() + ow.setConf(new Configuration()) + ow.readFields(in) + t = ow.get().asInstanceOf[T] + } +} diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala new file mode 100644 index 0000000000..307c383a89 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.Serializer + + +private[spark] abstract class ShuffleFetcher { + + /** + * Fetch the shuffle outputs for a given ShuffleDependency. + * @return An iterator over the elements of the fetched shuffle outputs. + */ + def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, + serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] + + /** Stop the fetcher */ + def stop() {} +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala new file mode 100644 index 0000000000..04d172a989 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -0,0 +1,1000 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io._ +import java.net.URI +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.Map +import scala.collection.generic.Growable +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.reflect.{ ClassTag, classTag} +import scala.util.DynamicVariable + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.ArrayWritable +import org.apache.hadoop.io.BooleanWritable +import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.DoubleWritable +import org.apache.hadoop.io.FloatWritable +import org.apache.hadoop.io.IntWritable +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.FileInputFormat +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.SequenceFileInputFormat +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + +import org.apache.mesos.MesosNativeLibrary + +import org.apache.spark.deploy.LocalSparkCluster +import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} +import org.apache.spark.rdd._ +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, + ClusterScheduler, Schedulable, SchedulingMode} +import org.apache.spark.scheduler.local.LocalScheduler +import org.apache.spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import org.apache.spark.storage.{StorageUtils, BlockManagerSource} +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage.StorageStatus + +/** + * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark + * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI. + * @param sparkHome Location where Spark is installed on cluster nodes. + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + * @param environment Environment variables to set on worker nodes. + */ +class SparkContext( + val master: String, + val appName: String, + val sparkHome: String = null, + val jars: Seq[String] = Nil, + val environment: Map[String, String] = Map(), + // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. + // This is typically generated from InputFormatInfo.computePreferredLocations .. 
host, set of data-local splits on host + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) + extends Logging { + + // Ensure logging is initialized before we spawn any threads + initLogging() + + // Set Spark driver host and port system properties + if (System.getProperty("spark.driver.host") == null) { + System.setProperty("spark.driver.host", Utils.localHostName()) + } + if (System.getProperty("spark.driver.port") == null) { + System.setProperty("spark.driver.port", "0") + } + + val isLocal = (master == "local" || master.startsWith("local[")) + + // Create the Spark execution environment (cache, map output tracker, etc) + private[spark] val env = SparkEnv.createFromSystemProperties( + "<driver>", + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port").toInt, + true, + isLocal) + SparkEnv.set(env) + + // Used to store a URL for each static file/jar together with the file's local timestamp + private[spark] val addedFiles = HashMap[String, Long]() + private[spark] val addedJars = HashMap[String, Long]() + + // Keeps track of all persisted RDDs + private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] + private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) + + // Initalize the Spark UI + private[spark] val ui = new SparkUI(this) + ui.bind() + + val startTime = System.currentTimeMillis() + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach { addJar(_) } + } + + // Environment variables to pass to our executors + private[spark] val executorEnvs = HashMap[String, String]() + // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner + for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { + val value = System.getenv(key) + if (value != null) { + executorEnvs(key) = value + } + } + // Since memory can be set with a system property too, use that + executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" + if (environment != null) { + executorEnvs ++= environment + } + + // Create and start the scheduler + private var taskScheduler: TaskScheduler = { + // Regular expression used for local[N] master format + val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """(spark://.*)""".r + //Regular expression for connection to Mesos cluster + val MESOS_REGEX = """(mesos://.*)""".r + + master match { + case "local" => + new LocalScheduler(1, 0, this) + + case LOCAL_N_REGEX(threads) => + new LocalScheduler(threads.toInt, 0, this) + + case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + new LocalScheduler(threads.toInt, maxFailures.toInt, this) + + case SPARK_REGEX(sparkUrl) => + val scheduler = new ClusterScheduler(this) + val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) + scheduler.initialize(backend) + scheduler + + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. 
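+        // For example (illustrative values, not from this file): a master URL of
+        // "local-cluster[2, 1, 512]" starts 2 workers with 1 core and 512 MB each, so a
+        // spark.executor.memory / SPARK_MEM request above 512 MB has to be rejected here.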
+ val memoryPerSlaveInt = memoryPerSlave.toInt + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + throw new SparkException( + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + } + + val scheduler = new ClusterScheduler(this) + val localCluster = new LocalSparkCluster( + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + val sparkUrl = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) + scheduler.initialize(backend) + backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + localCluster.stop() + } + scheduler + + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + scheduler.initialize(backend) + scheduler + + case _ => + if (MESOS_REGEX.findFirstIn(master).isEmpty) { + logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) + } + MesosNativeLibrary.load() + val scheduler = new ClusterScheduler(this) + val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// + val backend = if (coarseGrained) { + new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) + } else { + new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) + } + scheduler.initialize(backend) + scheduler + } + } + taskScheduler.start() + + @volatile private var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() + + ui.start() + + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. 
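+   *
+   * As a sketch (the property name is an illustrative assumption): launching the JVM with
+   * `-Dspark.hadoop.fs.defaultFS=hdfs://namenode:8020` makes `fs.defaultFS=hdfs://namenode:8020`
+   * visible in this configuration, since every "spark.hadoop.foo=bar" system property is copied
+   * in as "foo=bar".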
*/ + val hadoopConfiguration = { + val env = SparkEnv.get + val conf = env.hadoop.newConfiguration() + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + conf + } + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new DynamicVariable[Properties](null) + + def initLocalProperties() { + localProperties.value = new Properties() + } + + def setLocalProperty(key: String, value: String) { + if (localProperties.value == null) { + localProperties.value = new Properties() + } + if (value == null) { + localProperties.value.remove(key) + } else { + localProperties.value.setProperty(key, value) + } + } + + /** Set a human readable description of the current job. */ + def setJobDescription(value: String) { + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) + } + + // Post init + taskScheduler.postStartHook() + + val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) + val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) + + def initDriverMetrics() { + SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) + SparkEnv.get.metricsSystem.registerSource(blockManagerSource) + } + + initDriverMetrics() + + // Methods for creating RDDs + + /** Distribute a local Scala collection to form an RDD. */ + def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) + } + + /** Distribute a local Scala collection to form an RDD. */ + def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + parallelize(seq, numSlices) + } + + /** Distribute a local Scala collection to form an RDD, with one or more + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. */ + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap + new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) + } + + /** + * Read a text file from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI, and return it as an RDD of Strings. + */ + def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = { + hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits) + .map(pair => pair._2.toString) + } + + /** + * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any + * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, + * etc). 
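+   *
+   * A minimal call-site sketch (the path and the input format are assumptions for illustration):
+   * {{{
+   *   val conf = new JobConf()
+   *   FileInputFormat.setInputPaths(conf, "hdfs://namenode:8020/logs")
+   *   val lines = sc.hadoopRDD(conf, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
+   * }}}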
+ */ + def hadoopRDD[K, V]( + conf: JobConf, + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int = defaultMinSplits + ): RDD[(K, V)] = { + new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) + } + + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ + def hadoopFile[K, V]( + path: String, + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int = defaultMinSplits + ) : RDD[(K, V)] = { + val conf = new JobConf(hadoopConfiguration) + FileInputFormat.setInputPaths(conf, path) + new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) + } + + /** + * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys, + * values and the InputFormat so that users don't need to pass them directly. Instead, callers + * can just write, for example, + * {{{ + * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits) + * }}} + */ + def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) + : RDD[(K, V)] = { + hadoopFile(path, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + minSplits) + } + + /** + * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys, + * values and the InputFormat so that users don't need to pass them directly. Instead, callers + * can just write, for example, + * {{{ + * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) + * }}} + */ + def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = + hadoopFile[K, V, F](path, defaultMinSplits) + + /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ + def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { + newAPIHadoopFile( + path, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]]) + } + + /** + * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat + * and extra configuration options to pass to the input format. + */ + def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( + path: String, + fClass: Class[F], + kClass: Class[K], + vClass: Class[V], + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + val job = new NewHadoopJob(conf) + NewFileInputFormat.addInputPath(job, new Path(path)) + val updatedConf = job.getConfiguration + new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) + } + + /** + * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat + * and extra configuration options to pass to the input format. + */ + def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( + conf: Configuration = hadoopConfiguration, + fClass: Class[F], + kClass: Class[K], + vClass: Class[V]): RDD[(K, V)] = { + new NewHadoopRDD(this, fClass, kClass, vClass, conf) + } + + /** Get an RDD for a Hadoop SequenceFile with given key and value types. 
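+   *
+   * For example (hypothetical path; the key/value classes must match how the file was written):
+   * {{{
+   *   val counts = sc.sequenceFile("hdfs://namenode:8020/counts", classOf[Text], classOf[IntWritable])
+   * }}}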
*/ + def sequenceFile[K, V](path: String, + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int + ): RDD[(K, V)] = { + val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] + hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits) + } + + /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = + sequenceFile(path, keyClass, valueClass, defaultMinSplits) + + /** + * Version of sequenceFile() for types implicitly convertible to Writables through a + * WritableConverter. For example, to access a SequenceFile where the keys are Text and the + * values are IntWritable, you could simply write + * {{{ + * sparkContext.sequenceFile[String, Int](path, ...) + * }}} + * + * WritableConverters are provided in a somewhat strange way (by an implicit function) to support + * both subclasses of Writable and types for which we define a converter (e.g. Int to + * IntWritable). The most natural thing would've been to have implicit objects for the + * converters, but then we couldn't have an object for every subclass of Writable (you can't + * have a parameterized singleton object). We use functions instead to create a new converter + * for the appropriate type. In addition, we pass the converter a ClassTag of its type to + * allow it to figure out the Writable class to use in the subclass case. + */ + def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) + (implicit km: ClassTag[K], vm: ClassTag[V], + kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) + : RDD[(K, V)] = { + val kc = kcf() + val vc = vcf() + val format = classOf[SequenceFileInputFormat[Writable, Writable]] + val writables = hadoopFile(path, format, + kc.writableClass(km).asInstanceOf[Class[Writable]], + vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits) + writables.map{case (k,v) => (kc.convert(k), vc.convert(v))} + } + + /** + * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and + * BytesWritable values that contain a serialized partition. This is still an experimental storage + * format and may not be supported exactly as is in future Spark releases. It will also be pretty + * slow if you use the default serializer (Java serialization), though the nice thing about it is + * that there's very little effort required to save arbitrary objects. + */ + def objectFile[T: ClassTag]( + path: String, + minSplits: Int = defaultMinSplits + ): RDD[T] = { + sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits) + .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) + } + + + protected[spark] def checkpointFile[T: ClassTag]( + path: String + ): RDD[T] = { + new CheckpointRDD[T](this, path) + } + + /** Build the union of a list of RDDs. */ + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + + /** Build the union of a list of RDDs passed as variable-length arguments. */ + def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = + new UnionRDD(this, Seq(first) ++ rest) + + // Methods for creating shared variables + + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values + * to using the `+=` method. Only the driver can access the accumulator's `value`. 
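+   *
+   * A small usage sketch (values are illustrative):
+   * {{{
+   *   val sum = sc.accumulator(0)
+   *   sc.parallelize(1 to 100).foreach(x => sum += x)
+   *   // sum.value is 5050 on the driver; tasks themselves cannot read it
+   * }}}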
+ */ + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = + new Accumulator(initialValue, param) + + /** + * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the driver can access the accumuable's `value`. + * @tparam T accumulator type + * @tparam R type that can be added to the accumulator + */ + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = + new Accumulable(initialValue, param) + + /** + * Create an accumulator from a "mutable collection" type. + * + * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by + * standard mutable collections. So you can use this with mutable Map, Set, etc. + */ + def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = { + val param = new GrowableAccumulableParam[R,T] + new Accumulable(initialValue, param) + } + + /** + * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for + * reading it in distributed functions. The variable will be sent to each cluster only once. + */ + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. + */ + def addFile(path: String) { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) + case _ => path + } + addedFiles(key) = System.currentTimeMillis + + // Fetch the file locally in case a job is executed locally. + // Jobs that run through LocalScheduler will already fetch the required dependencies, + // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) + + logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) + } + + def addSparkListener(listener: SparkListener) { + dagScheduler.addSparkListener(listener) + } + + /** + * Return a map from the slave to the max memory available for caching and the remaining + * memory available for caching. + */ + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => + (blockManagerId.host + ":" + blockManagerId.port, mem) + } + } + + /** + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + def getRDDStorageInfo: Array[RDDInfo] = { + StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) + } + + /** + * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. + * Note that this does not necessarily mean the caching or computation was successful. 
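+   *
+   * Sketch (illustrative): after `val r = sc.parallelize(1 to 10).cache()`, the returned map is
+   * expected to contain an entry for `r.id`, even before any action has actually filled the cache.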
+ */ + def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap + + def getStageInfo: Map[Stage,StageInfo] = { + dagScheduler.stageToInfos + } + + /** + * Return information about blocks stored in all of the slaves + */ + def getExecutorStorageStatus: Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + + /** + * Return pools for fair scheduler + * TODO(xiajunluan): We should take nested pools into account + */ + def getAllPools: ArrayBuffer[Schedulable] = { + taskScheduler.rootPool.schedulableQueue + } + + /** + * Return the pool associated with the given name, if one exists + */ + def getPoolForName(pool: String): Option[Schedulable] = { + taskScheduler.rootPool.schedulableNameToSchedulable.get(pool) + } + + /** + * Return current scheduling mode + */ + def getSchedulingMode: SchedulingMode.SchedulingMode = { + taskScheduler.schedulingMode + } + + /** + * Clear the job's list of files added by `addFile` so that they do not get downloaded to + * any new nodes. + */ + def clearFiles() { + addedFiles.clear() + } + + /** + * Gets the locality information associated with the partition in a particular rdd + * @param rdd of interest + * @param partition to be looked up for locality + * @return list of preferred locations for the partition + */ + private [spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { + dagScheduler.getPreferredLocs(rdd, partition) + } + + /** + * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addJar(path: String) { + if (path == null) { + logWarning("null specified as parameter to addJar", + new SparkException("null specified as parameter to addJar")) + } else { + var key = "" + if (path.contains("\\")) { + // For local paths with backslashes on Windows, URI throws an exception + key = env.httpFileServer.addJar(new File(path)) + } else { + val uri = new URI(path) + key = uri.getScheme match { + case null | "file" => + if (env.hadoop.isYarnMode()) { + logWarning("local jar specified as parameter to addJar under Yarn mode") + return + } + env.httpFileServer.addJar(new File(uri.getPath)) + case _ => + path + } + } + addedJars(key) = System.currentTimeMillis + logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) + } + } + + /** + * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to + * any new nodes. + */ + def clearJars() { + addedJars.clear() + } + + /** Shut down the SparkContext. */ + def stop() { + ui.stop() + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { + metadataCleaner.cancel() + dagSchedulerCopy.stop() + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + // Clean up locally linked files + clearFiles() + clearJars() + SparkEnv.set(null) + ShuffleMapTask.clearCache() + ResultTask.clearCache() + logInfo("Successfully stopped SparkContext") + } else { + logInfo("SparkContext already stopped") + } + } + + + /** + * Get Spark's home location from either a value set through the constructor, + * or the spark.home Java property, or the SPARK_HOME environment variable + * (in that order of preference). If neither of these is set, return None. 
+ */ + private[spark] def getSparkHome(): Option[String] = { + if (sparkHome != null) { + Some(sparkHome) + } else if (System.getProperty("spark.home") != null) { + Some(System.getProperty("spark.home")) + } else if (System.getenv("SPARK_HOME") != null) { + Some(System.getenv("SPARK_HOME")) + } else { + None + } + } + + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. The allowLocal + * flag specifies whether the scheduler can run the computation on the driver rather than + * shipping it out to the cluster, for short actions like first(). + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { + val callSite = Utils.formatSparkCallSite + logInfo("Starting job: " + callSite) + val start = System.nanoTime + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) + logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() + result + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. The + * allowLocal flag specifies whether the scheduler can run the computation on the driver rather + * than shipping it out to the cluster, for short actions like first(). + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean + ): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) + results + } + + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: Iterator[T] => U, + partitions: Seq[Int], + allowLocal: Boolean + ): Array[U] = { + runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal) + } + + /** + * Run a job on all partitions in an RDD and return the results in an array. + */ + def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { + runJob(rdd, func, 0 until rdd.partitions.size, false) + } + + /** + * Run a job on all partitions in an RDD and return the results in an array. + */ + def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { + runJob(rdd, func, 0 until rdd.partitions.size, false) + } + + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + processPartition: (TaskContext, Iterator[T]) => U, + resultHandler: (Int, U) => Unit) + { + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) + } + + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + resultHandler: (Int, U) => Unit) + { + val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) + } + + /** + * Run a job that can return approximate results. 
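+   *
+   * A hedged call-site sketch; `rdd` and `evaluator` here are assumptions (any
+   * [[org.apache.spark.partial.ApproximateEvaluator]] matching the function's result type would do):
+   * {{{
+   *   // count elements, but give up after 10 seconds and return the evaluator's current estimate
+   *   val partial = sc.runApproximateJob(rdd,
+   *     (ctx: TaskContext, it: Iterator[Int]) => it.size.toLong, evaluator, 10000L)
+   * }}}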
+ */ + def runApproximateJob[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + timeout: Long): PartialResult[R] = { + val callSite = Utils.formatSparkCallSite + logInfo("Starting job: " + callSite) + val start = System.nanoTime + val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) + logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + result + } + + /** + * Clean a closure to make it ready to serialized and send to tasks + * (removes unreferenced variables in $outer's, updates REPL variables) + */ + private[spark] def clean[F <: AnyRef](f: F): F = { + ClosureCleaner.clean(f) + return f + } + + /** + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. + */ + def setCheckpointDir(dir: String, useExisting: Boolean = false) { + val env = SparkEnv.get + val path = new Path(dir) + val fs = path.getFileSystem(env.hadoop.newConfiguration()) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } + } + checkpointDir = Some(dir) + } + + /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ + def defaultParallelism: Int = taskScheduler.defaultParallelism + + /** Default min number of partitions for Hadoop RDDs when not given by user */ + def defaultMinSplits: Int = math.min(defaultParallelism, 2) + + private val nextShuffleId = new AtomicInteger(0) + + private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement() + + private val nextRddId = new AtomicInteger(0) + + /** Register a new RDD, returning its RDD ID */ + private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + + /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ + private[spark] def cleanup(cleanupTime: Long) { + persistentRdds.clearOldValues(cleanupTime) + } +} + +/** + * The SparkContext object contains a number of implicit conversions and parameters for use with + * various Spark features. + */ +object SparkContext { + val SPARK_JOB_DESCRIPTION = "spark.job.description" + + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double) = 0.0 + } + + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int) = 0 + } + + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0l + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + + // TODO: Add AccumulatorParams for other types, e.g. 
lists and strings + + implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = + new PairRDDFunctions(rdd) + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + rdd: RDD[(K, V)]) = + new SequenceFileRDDFunctions(rdd) + + implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)]) = + new OrderedRDDFunctions[K, V, (K, V)](rdd) + + implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + + implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) + + // Implicit conversions to common Writable types, for saveAsSequenceFile + + implicit def intToIntWritable(i: Int) = new IntWritable(i) + + implicit def longToLongWritable(l: Long) = new LongWritable(l) + + implicit def floatToFloatWritable(f: Float) = new FloatWritable(f) + + implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d) + + implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b) + + implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob) + + implicit def stringToText(s: String) = new Text(s) + + private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = { + def anyToWritable[U <% Writable](u: U): Writable = u + + new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]], + arr.map(x => anyToWritable(x)).toArray) + } + + // Helper objects for converting common types to Writable + private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] + new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) + } + + implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get) + + implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get) + + implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get) + + implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get) + + implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get) + + implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + + implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) + + implicit def writableWritableConverter[T <: Writable]() = + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) + + /** + * Find the JAR from which a given class was loaded, to make it easy for users to pass + * their JARs to SparkContext + */ + def jarOfClass(cls: Class[_]): Seq[String] = { + val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") + if (uri != null) { + val uriStr = uri.toString + if (uriStr.startsWith("jar:file:")) { + // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar + List(uriStr.substring("jar:file:".length, uriStr.indexOf('!'))) + } else { + Nil + } + } else { + Nil + } + } + + /** Find the JAR that contains the class of a particular object */ + def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ + private[spark] val executorMemoryRequested = { + // TODO: Might need to add some extra 
memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + } +} + +/** + * A class encapsulating how to convert some type T to Writable. It stores both the Writable class + * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. + * The getter for the writable class takes a ClassTag[T] in case this is a generic object + * that doesn't know the type of T when it is created. This sounds strange but is necessary to + * support converting subclasses of Writable to themselves (writableWritableConverter). + */ +private[spark] class WritableConverter[T]( + val writableClass: ClassTag[T] => Class[_ <: Writable], + val convert: Writable => T) + extends Serializable + diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala new file mode 100644 index 0000000000..1e63b54b7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import collection.mutable +import serializer.Serializer + +import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} +import akka.remote.RemoteActorRefProvider + +import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster} +import org.apache.spark.network.ConnectionManager +import org.apache.spark.serializer.{Serializer, SerializerManager} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.api.python.PythonWorkerFactory + + +/** + * Holds all the runtime environment objects for a running Spark instance (either master or worker), + * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently + * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these + * objects needs to have the right SparkEnv set. You can get the current environment with + * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. 
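+ *
+ * For instance (a sketch assuming a SparkContext has already been created on this thread):
+ * {{{
+ *   val blockManager = SparkEnv.get.blockManager
+ * }}}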
+ */ +class SparkEnv ( + val executorId: String, + val actorSystem: ActorSystem, + val serializerManager: SerializerManager, + val serializer: Serializer, + val closureSerializer: Serializer, + val cacheManager: CacheManager, + val mapOutputTracker: MapOutputTracker, + val shuffleFetcher: ShuffleFetcher, + val broadcastManager: BroadcastManager, + val blockManager: BlockManager, + val connectionManager: ConnectionManager, + val httpFileServer: HttpFileServer, + val sparkFilesDir: String, + val metricsSystem: MetricsSystem) { + + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() + + val hadoop = { + val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if(yarnMode) { + try { + Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] + } catch { + case th: Throwable => throw new SparkException("Unable to load YARN support", th) + } + } else { + new SparkHadoopUtil + } + } + + def stop() { + pythonWorkers.foreach { case(key, worker) => worker.stop() } + httpFileServer.stop() + mapOutputTracker.stop() + shuffleFetcher.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + actorSystem.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + //actorSystem.awaitTermination() + } + + def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() + } + } +} + +object SparkEnv extends Logging { + private val env = new ThreadLocal[SparkEnv] + @volatile private var lastSetSparkEnv : SparkEnv = _ + + def set(e: SparkEnv) { + lastSetSparkEnv = e + env.set(e) + } + + /** + * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv + * previously set in any thread. + */ + def get: SparkEnv = { + Option(env.get()).getOrElse(lastSetSparkEnv) + } + + /** + * Returns the ThreadLocal SparkEnv. + */ + def getThreadLocal : SparkEnv = { + env.get() + } + + def createFromSystemProperties( + executorId: String, + hostname: String, + port: Int, + isDriver: Boolean, + isLocal: Boolean): SparkEnv = { + + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) + + // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), + // figure out which port number Akka actually bound to and set spark.driver.port to it. + if (isDriver && port == 0) { + System.setProperty("spark.driver.port", boundPort.toString) + } + + // set only if unset until now. 
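+    // For example (illustrative values only): with hostname "worker-1" and boundPort 48123,
+    // the property set below becomes spark.hostPort = "worker-1:48123".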
+ if (System.getProperty("spark.hostPort", null) == null) { + if (!isDriver){ + // unexpected + Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") + } + Utils.checkHost(hostname) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + } + + val classLoader = Thread.currentThread.getContextClassLoader + + // Create an instance of the class named by the given Java system property, or by + // defaultClassName if the property is not set, and return it as a T + def instantiateClass[T](propertyName: String, defaultClassName: String): T = { + val name = System.getProperty(propertyName, defaultClassName) + Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] + } + + val serializerManager = new SerializerManager + + val serializer = serializerManager.setDefault( + System.getProperty("spark.serializer", "org.apache.spark.serializer.JavaSerializer")) + + val closureSerializer = serializerManager.get( + System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")) + + def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + if (isDriver) { + logInfo("Registering " + name) + actorSystem.actorOf(Props(newActor), name = name) + } else { + val driverHost: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt + Utils.checkHost(driverHost, "Expected hostname") + val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) + logInfo("Connecting to " + name + ": " + url) + actorSystem.actorFor(url) + } + } + + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( + "BlockManagerMaster", + new BlockManagerMasterActor(isLocal))) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) + + val connectionManager = blockManager.connectionManager + + val broadcastManager = new BroadcastManager(isDriver) + + val cacheManager = new CacheManager(blockManager) + + // Have to assign trackerActor after initialization as MapOutputTrackerActor + // requires the MapOutputTracker itself + val mapOutputTracker = new MapOutputTracker() + mapOutputTracker.trackerActor = registerOrLookup( + "MapOutputTracker", + new MapOutputTrackerActor(mapOutputTracker)) + + val shuffleFetcher = instantiateClass[ShuffleFetcher]( + "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") + + val httpFileServer = new HttpFileServer() + httpFileServer.initialize() + System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + + val metricsSystem = if (isDriver) { + MetricsSystem.createMetricsSystem("driver") + } else { + MetricsSystem.createMetricsSystem("executor") + } + metricsSystem.start() + + // Set the sparkFiles directory, used when downloading dependencies. In local mode, + // this is a temporary directory; in distributed mode, this is the executor's current working + // directory. + val sparkFilesDir: String = if (isDriver) { + Utils.createTempDir().getAbsolutePath + } else { + "." + } + + // Warn about deprecated spark.cache.class property + if (System.getProperty("spark.cache.class") != null) { + logWarning("The spark.cache.class property is no longer being used! 
Specify storage " + + "levels using the RDD.persist() method instead.") + } + + new SparkEnv( + executorId, + actorSystem, + serializerManager, + serializer, + closureSerializer, + cacheManager, + mapOutputTracker, + shuffleFetcher, + broadcastManager, + blockManager, + connectionManager, + httpFileServer, + sparkFilesDir, + metricsSystem) + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala new file mode 100644 index 0000000000..d34e47e8ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +class SparkException(message: String, cause: Throwable) + extends Exception(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/core/src/main/scala/org/apache/spark/SparkFiles.java b/core/src/main/scala/org/apache/spark/SparkFiles.java new file mode 100644 index 0000000000..af9cf85e37 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkFiles.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.File; + +/** + * Resolves paths to files added through `SparkContext.addFile()`. + */ +public class SparkFiles { + + private SparkFiles() {} + + /** + * Get the absolute path of a file added through `SparkContext.addFile()`. + */ + public static String get(String filename) { + return new File(getRootDirectory(), filename).getAbsolutePath(); + } + + /** + * Get the root directory that contains files added through `SparkContext.addFile()`. 
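+   *
+   * For example, a file registered on the driver as `SparkContext.addFile("data.txt")` (an
+   * illustrative name) can later be opened on an executor via `SparkFiles.get("data.txt")`,
+   * which resolves to a path under this root directory.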
+ */ + public static String getRootDirectory() { + return SparkEnv.get().sparkFilesDir(); + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala new file mode 100644 index 0000000000..2bab9d6e3d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.mapred + +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path + +import java.text.SimpleDateFormat +import java.text.NumberFormat +import java.io.IOException +import java.util.Date + +import org.apache.spark.Logging +import org.apache.spark.SerializableWritable + +/** + * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public + * because we need to access this class from the `spark` package to use some package-private Hadoop + * functions, but this class should not be used directly by users. + * + * Saves the RDD using a JobConf, which should contain an output key class, an output value class, + * a filename to write to, etc, exactly like in a Hadoop MapReduce job. 
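+ *
+ * The expected call sequence, sketched from the methods defined below (the driver/task split is
+ * an assumption about typical use, not something this file spells out):
+ * {{{
+ *   val writer = new SparkHadoopWriter(jobConf)
+ *   writer.preSetup()                          // once, on the driver
+ *   writer.setup(jobId, splitId, attemptId)    // then, inside each task
+ *   writer.open()
+ *   writer.write(key, value)                   // repeatedly
+ *   writer.close()
+ *   writer.commit()
+ *   writer.commitJob()                         // finally, back on the driver
+ * }}}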
+ */ +class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { + + private val now = new Date() + private val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: RecordWriter[AnyRef,AnyRef] = null + @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var committer: OutputCommitter = null + @transient private var jobContext: JobContext = null + @transient private var taskContext: TaskAttemptContext = null + + def preSetup() { + setIDs(0, 0, 0) + setConfParams() + + val jCtxt = getJobContext() + getOutputCommitter().setupJob(jCtxt) + } + + + def setup(jobid: Int, splitid: Int, attemptid: Int) { + setIDs(jobid, splitid, attemptid) + setConfParams() + } + + def open() { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + + val outputName = "part-" + numfmt.format(splitID) + val path = FileOutputFormat.getOutputPath(conf.value) + val fs: FileSystem = { + if (path != null) { + path.getFileSystem(conf.value) + } else { + FileSystem.get(conf.value) + } + } + + getOutputCommitter().setupTask(getTaskContext()) + writer = getOutputFormat().getRecordWriter( + fs, conf.value, outputName, Reporter.NULL) + } + + def write(key: AnyRef, value: AnyRef) { + if (writer!=null) { + //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") + writer.write(key, value) + } else { + throw new IOException("Writer is null, open() has not been called") + } + } + + def close() { + writer.close(Reporter.NULL) + } + + def commit() { + val taCtxt = getTaskContext() + val cmtr = getOutputCommitter() + if (cmtr.needsTaskCommit(taCtxt)) { + try { + cmtr.commitTask(taCtxt) + logInfo (taID + ": Committed") + } catch { + case e: IOException => { + logError("Error committing the output of task: " + taID.value, e) + cmtr.abortTask(taCtxt) + throw e + } + } + } else { + logWarning ("No need to commit output of task: " + taID.value) + } + } + + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? 
+ val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + + def cleanup() { + getOutputCommitter().cleanupJob(getJobContext()) + } + + // ********* Private Functions ********* + + private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + if (format == null) { + format = conf.value.getOutputFormat() + .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + } + return format + } + + private def getOutputCommitter(): OutputCommitter = { + if (committer == null) { + committer = conf.value.getOutputCommitter + } + return committer + } + + private def getJobContext(): JobContext = { + if (jobContext == null) { + jobContext = newJobContext(conf.value, jID.value) + } + return jobContext + } + + private def getTaskContext(): TaskAttemptContext = { + if (taskContext == null) { + taskContext = newTaskAttemptContext(conf.value, taID.value) + } + return taskContext + } + + private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { + jobID = jobid + splitID = splitid + attemptID = attemptid + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +object SparkHadoopWriter { + def createJobID(time: Date, id: Int): JobID = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + return new JobID(jobtrackerID, id) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + var outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath = outputPath.makeQualified(fs) + return outputPath + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala new file mode 100644 index 0000000000..b2dd668330 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import executor.TaskMetrics +import scala.collection.mutable.ArrayBuffer + +class TaskContext( + val stageId: Int, + val splitId: Int, + val attemptId: Long, + val taskMetrics: TaskMetrics = TaskMetrics.empty() +) extends Serializable { + + @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] + + // Add a callback function to be executed on task completion. An example use + // is for HadoopRDD to register a callback to close the input stream. + def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += f + } + + def executeOnCompleteCallbacks() { + onCompleteCallbacks.foreach{_()} + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala new file mode 100644 index 0000000000..03bf268863 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId + +/** + * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry + * tasks several times for "ephemeral" failures, and only report back failures that require some + * old stages to be resubmitted, such as shuffle map fetch failures. + */ +private[spark] sealed trait TaskEndReason + +private[spark] case object Success extends TaskEndReason + +private[spark] +case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it + +private[spark] case class FetchFailed( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int) + extends TaskEndReason + +private[spark] case class ExceptionFailure( + className: String, + description: String, + stackTrace: Array[StackTraceElement], + metrics: Option[TaskMetrics]) + extends TaskEndReason + +private[spark] case class OtherFailure(message: String) extends TaskEndReason + +private[spark] case class TaskResultTooBigFailure() extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala new file mode 100644 index 0000000000..0bf1e4a5e2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.mesos.Protos.{TaskState => MesosTaskState} + +private[spark] object TaskState extends Enumeration { + + val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value + + val FINISHED_STATES = Set(FINISHED, FAILED, KILLED, LOST) + + type TaskState = Value + + def isFinished(state: TaskState) = FINISHED_STATES.contains(state) + + def toMesos(state: TaskState): MesosTaskState = state match { + case LAUNCHING => MesosTaskState.TASK_STARTING + case RUNNING => MesosTaskState.TASK_RUNNING + case FINISHED => MesosTaskState.TASK_FINISHED + case FAILED => MesosTaskState.TASK_FAILED + case KILLED => MesosTaskState.TASK_KILLED + case LOST => MesosTaskState.TASK_LOST + } + + def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match { + case MesosTaskState.TASK_STAGING => LAUNCHING + case MesosTaskState.TASK_STARTING => LAUNCHING + case MesosTaskState.TASK_RUNNING => RUNNING + case MesosTaskState.TASK_FINISHED => FINISHED + case MesosTaskState.TASK_FAILED => FAILED + case MesosTaskState.TASK_KILLED => KILLED + case MesosTaskState.TASK_LOST => LOST + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala new file mode 100644 index 0000000000..f0a1960a1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.util.StatCounter +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.storage.StorageLevel + +import java.lang.Double +import org.apache.spark.Partitioner + +class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { + + override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]] + + override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x)) + + override def wrapRDD(rdd: RDD[Double]): JavaDoubleRDD = + new JavaDoubleRDD(rdd.map(_.doubleValue)) + + // Common RDD functions + + import JavaDoubleRDD.fromRDD + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) + + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. Can only be called once on each RDD. + */ + def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) + + // first() has to be overriden here in order for its return type to be Double instead of Object. + override def first(): Double = srdd.first() + + // Transformations (return a new RDD) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct()) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ + def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD = + fromRDD(srdd.filter(x => f(x).booleanValue())) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions)) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD = + fromRDD(srdd.coalesce(numPartitions, shuffle)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: JavaDoubleRDD): JavaDoubleRDD = + fromRDD(srdd.subtract(other)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD = + fromRDD(srdd.subtract(other, numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD = + fromRDD(srdd.subtract(other, p)) + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD = + fromRDD(srdd.sample(withReplacement, fraction, seed)) + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ + def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd)) + + // Double RDD functions + + /** Add up the elements in this RDD. 
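+ * An illustrative sketch (the `sc` SparkContext and the sample values are assumed, not part of
+ * this file):
+ * {{{
+ *   val doubles = JavaDoubleRDD.fromRDD(sc.parallelize(Seq(1.0, 2.0, 3.0)))
+ *   doubles.sum()    // 6.0
+ *   doubles.stats()  // count, mean, variance and stdev in a single pass
+ * }}}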
*/ + def sum(): Double = srdd.sum() + + /** + * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count + * of the RDD's elements in one operation. + */ + def stats(): StatCounter = srdd.stats() + + /** Compute the mean of this RDD's elements. */ + def mean(): Double = srdd.mean() + + /** Compute the variance of this RDD's elements. */ + def variance(): Double = srdd.variance() + + /** Compute the standard deviation of this RDD's elements. */ + def stdev(): Double = srdd.stdev() + + /** + * Compute the sample standard deviation of this RDD's elements (which corrects for bias in + * estimating the standard deviation by dividing by N-1 instead of N). + */ + def sampleStdev(): Double = srdd.sampleStdev() + + /** + * Compute the sample variance of this RDD's elements (which corrects for bias in + * estimating the standard variance by dividing by N-1 instead of N). + */ + def sampleVariance(): Double = srdd.sampleVariance() + + /** Return the approximate mean of the elements in this RDD. */ + def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = + srdd.meanApprox(timeout, confidence) + + /** (Experimental) Approximate operation to return the mean within a timeout. */ + def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) + + /** (Experimental) Approximate operation to return the sum within a timeout. */ + def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = + srdd.sumApprox(timeout, confidence) + + /** (Experimental) Approximate operation to return the sum within a timeout. */ + def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) +} + +object JavaDoubleRDD { + def fromRDD(rdd: RDD[scala.Double]): JavaDoubleRDD = new JavaDoubleRDD(rdd) + + implicit def toRDD(rdd: JavaDoubleRDD): RDD[scala.Double] = rdd.srdd +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala new file mode 100644 index 0000000000..899e17d4fa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -0,0 +1,602 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java + +import java.util.{List => JList} +import java.util.Comparator + +import scala.Tuple2 +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import com.google.common.base.Optional +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.HashPartitioner +import org.apache.spark.Partitioner +import org.apache.spark.Partitioner._ +import org.apache.spark.SparkContext.rddToPairRDDFunctions +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.partial.BoundedDouble +import org.apache.spark.partial.PartialResult +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.OrderedRDDFunctions +import org.apache.spark.storage.StorageLevel + + +class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K], + implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { + + override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) + + override val classTag: ClassTag[(K, V)] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]] + + import JavaPairRDD._ + + // Common RDD functions + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) + + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. Can only be called once on each RDD. + */ + def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.persist(newLevel)) + + // Transformations (return a new RDD) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct()) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ + def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions)) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] = + fromRDD(rdd.coalesce(numPartitions, shuffle)) + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). 
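+ * A minimal sketch (assuming two JavaPairRDDs `a` and `b` with the same key and value types):
+ * {{{
+ *   val merged  = a.union(b)           // keeps duplicate pairs
+ *   val deduped = merged.distinct()    // eliminates them
+ * }}}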
+ */ + def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.union(other.rdd)) + + // first() has to be overridden here so that the generated method has the signature + // 'public scala.Tuple2 first()'; if the trait's definition is used, + // then the method has the signature 'public java.lang.Object first()', + // causing NoSuchMethodErrors at runtime. + override def first(): (K, V) = rdd.first() + + // Pair RDD functions + + /** + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD, and whether to perform + * map-side aggregation (if a mapper can produce multiple items with the same key). + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner): JavaPairRDD[K, C] = { + implicit val cm: ClassTag[C] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] + fromRDD(rdd.combineByKey( + createCombiner, + mergeValue, + mergeCombiners, + partitioner + )) + } + + /** + * Simplified version of combineByKey that hash-partitions the output RDD. + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + numPartitions: Int): JavaPairRDD[K, C] = + combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ + def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.reduceByKey(partitioner, func)) + + /** + * Merge the values for each key using an associative reduce function, but return the results + * immediately to the master as a Map. This will also perform the merging locally on each mapper + * before sending results to a reducer, similarly to a "combiner" in MapReduce. + */ + def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = + mapAsJavaMap(rdd.reduceByKeyLocally(func)) + + /** Count the number of elements for each key, and return the result to the master as a Map. */ + def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. + */ + def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = + rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. 
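+ * An illustrative sketch (assuming a JavaPairRDD named `pairs`; the timeout is arbitrary):
+ * {{{
+ *   val approx = pairs.countByKeyApprox(10000)   // wait at most 10 seconds
+ *   val counts = approx.getFinalValue()          // blocks until the result is final
+ * }}}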
+ */ + def countByKeyApprox(timeout: Long, confidence: Double = 0.95) + : PartialResult[java.util.Map[K, BoundedDouble]] = + rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue, partitioner)(func)) + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func)) + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue)(func)) + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] = + fromRDD(rdd.reduceByKey(func, numPartitions)) + + /** + * Group the values for each key in the RDD into a single sequence. Allows controlling the + * partitioning of the resulting key-value pair RDD by passing a Partitioner. + */ + def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] = + fromRDD(groupByResultToJava(rdd.groupByKey(partitioner))) + + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with into `numPartitions` partitions. + */ + def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] = + fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions))) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = + fromRDD(rdd.subtract(other)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] = + fromRDD(rdd.subtract(other, numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] = + fromRDD(rdd.subtract(other, p)) + + /** + * Return a copy of the RDD partitioned using the specified partitioner. + */ + def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] = + fromRDD(rdd.partitionBy(partitioner)) + + /** + * Merge the values for each key using an associative reduce function. 
This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ + def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = + fromRDD(rdd.join(other, partitioner)) + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to + * partition the output RDD. + */ + def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) + : JavaPairRDD[K, (V, Optional[W])] = { + val joinResult = rdd.leftOuterJoin(other, partitioner) + fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to + * partition the output RDD. + */ + def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) + : JavaPairRDD[K, (Optional[V], W)] = { + val joinResult = rdd.rightOuterJoin(other, partitioner) + fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) + } + + /** + * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing + * partitioner/parallelism level. + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { + implicit val cm: ClassTag[C] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] + fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd))) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. + */ + def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = { + fromRDD(reduceByKey(defaultPartitioner(rdd), func)) + } + + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with the existing partitioner/parallelism level. + */ + def groupByKey(): JavaPairRDD[K, JList[V]] = + fromRDD(groupByResultToJava(rdd.groupByKey())) + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ + def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] = + fromRDD(rdd.join(other)) + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ + def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] = + fromRDD(rdd.join(other, numPartitions)) + + /** + * Perform a left outer join of `this` and `other`. 
For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * using the existing partitioner/parallelism level. + */ + def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Optional[W])] = { + val joinResult = rdd.leftOuterJoin(other) + fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * into `numPartitions` partitions. + */ + def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Optional[W])] = { + val joinResult = rdd.leftOuterJoin(other, numPartitions) + fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD using the existing partitioner/parallelism level. + */ + def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], W)] = { + val joinResult = rdd.rightOuterJoin(other) + fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD into the given number of partitions. + */ + def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Optional[V], W)] = { + val joinResult = rdd.rightOuterJoin(other, numPartitions) + fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) + } + + /** + * Return the key-value pairs in this RDD to the master as a Map. + */ + def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + + /** + * Pass each value in the key-value pair RDD through a map function without changing the keys; + * this also retains the original RDD's partitioning. + */ + def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = { + implicit val cm: ClassTag[U] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] + fromRDD(rdd.mapValues(f)) + } + + /** + * Pass each value in the key-value pair RDD through a flatMap function without changing the + * keys; this also retains the original RDD's partitioning. + */ + def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { + import scala.collection.JavaConverters._ + def fn = (x: V) => f.apply(x).asScala + implicit val cm: ClassTag[U] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] + fromRDD(rdd.flatMapValues(fn)) + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. 
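+ * An illustrative sketch (the RDD names are assumed):
+ * {{{
+ *   val grouped = users.cogroup(orders)
+ *   // each value is a pair of java.util.Lists: (values from `users`, values from `orders`)
+ * }}}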
+ */ + def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner) + : JavaPairRDD[K, (JList[V], JList[W])] = + fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner))) + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner) + : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = + fromRDD(cogroupResultToJava(rdd.cogroup(other))) + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) + : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])] + = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions))) + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int) + : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) + + /** Alias for cogroup. */ + def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = + fromRDD(cogroupResultToJava(rdd.groupWith(other))) + + /** Alias for cogroup. */ + def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) + : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) + + /** + * Return the list of values in the RDD for key `key`. This operation is done efficiently if the + * RDD has a known partitioner by only searching the partition that the key maps to. + */ + def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + + /** Output the RDD to any Hadoop-supported file system. */ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + conf: JobConf) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) + } + + /** Output the RDD to any Hadoop-supported file system. */ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F]) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) + } + + /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. 
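+ * An illustrative sketch (the output path and key/value classes are assumed; Text, IntWritable,
+ * TextOutputFormat and GzipCodec come from the Hadoop io/mapred packages):
+ * {{{
+ *   pairs.saveAsHadoopFile("hdfs:///tmp/counts", classOf[Text], classOf[IntWritable],
+ *     classOf[TextOutputFormat[Text, IntWritable]], classOf[GzipCodec])
+ * }}}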
*/ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + codec: Class[_ <: CompressionCodec]) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) + } + + /** Output the RDD to any Hadoop-supported file system. */ + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + conf: Configuration) { + rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) + } + + /** Output the RDD to any Hadoop-supported file system. */ + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F]) { + rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) + } + + /** + * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for + * that storage system. The JobConf should set an OutputFormat and any output paths required + * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop + * MapReduce job. + */ + def saveAsHadoopDataset(conf: JobConf) { + rdd.saveAsHadoopDataset(conf) + } + + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements in + * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an + * ordered list of records (in the `save` case, they will be written to multiple `part-X` files + * in the filesystem, in order of the keys). + */ + def sortByKey(): JavaPairRDD[K, V] = sortByKey(true) + + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + sortByKey(comp, ascending) + } + + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true) + + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = { + class KeyOrdering(val a: K) extends Ordered[K] { + override def compare(b: K) = comp.compare(a, b) + } + implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) + fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending)) + } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) + + /** + * Return an RDD with the values of each tuple. 
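+ * An illustrative sketch (assuming a JavaPairRDD named `pairs`):
+ * {{{
+ *   val ids      = pairs.keys()     // RDD of the first tuple elements
+ *   val payloads = pairs.values()   // RDD of the second tuple elements
+ * }}}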
+ */ + def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) +} + +object JavaPairRDD { + def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K], + vcm: ClassTag[T]): RDD[(K, JList[T])] = + rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _) + + def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K], + vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V], + Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2))) + + def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1], + Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1], + JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues( + (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1), + seqAsJavaList(x._2), + seqAsJavaList(x._3))) + + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd) + + implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala new file mode 100644 index 0000000000..9968bc8e5f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.storage.StorageLevel + +class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends +JavaRDDLike[T, JavaRDD[T]] { + + override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) + + // Common RDD functions + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) + + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. + */ + def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel)) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + */ + def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) + + // Transformations (return a new RDD) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct()) + + /** + * Return a new RDD containing the distinct elements in this RDD. 
+ */ + def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ + def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] = + wrapRDD(rdd.filter((x => f(x).booleanValue()))) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] = + rdd.coalesce(numPartitions, shuffle) + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = + wrapRDD(rdd.sample(withReplacement, fraction, seed)) + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ + def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] = + wrapRDD(rdd.subtract(other, numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = + wrapRDD(rdd.subtract(other, p)) +} + +object JavaRDD { + + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd) + + implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala new file mode 100644 index 0000000000..feb2cab578 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java + +import java.util.{List => JList, Comparator} +import scala.Tuple2 +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import com.google.common.base.Optional +import org.apache.hadoop.io.compress.CompressionCodec + +import org.apache.spark.{SparkContext, Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.api.java.JavaPairRDD._ +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} +import org.apache.spark.partial.{PartialResult, BoundedDouble} +import org.apache.spark.storage.StorageLevel + + +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { + def wrapRDD(rdd: RDD[T]): This + + implicit val classTag: ClassTag[T] + + def rdd: RDD[T] + + /** Set of partitions in this RDD. */ + def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ + def context: SparkContext = rdd.context + + /** A unique ID for this RDD (within its SparkContext). */ + def id: Int = rdd.id + + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ + def getStorageLevel: StorageLevel = rdd.getStorageLevel + + /** + * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. + * This should ''not'' be called by users directly, but is available for implementors of custom + * subclasses of RDD. + */ + def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = + asJavaIterator(rdd.iterator(split, taskContext)) + + // Transformations (return a new RDD) + + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ + def map[R](f: JFunction[T, R]): JavaRDD[R] = + new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType()) + + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ + def map[R](f: DoubleFunction[T]): JavaDoubleRDD = + new JavaDoubleRDD(rdd.map(x => f(x).doubleValue())) + + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ + def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { + def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] + new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType()) + } + + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ + def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType()) + } + + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ + def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + } + + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. 
+ */ + def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] + JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ + def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ + def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ + def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): + JavaPairRDD[K2, V2] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType()) + } + + /** + * Return an RDD created by coalescing all elements within each partition into an array. + */ + def glom(): JavaRDD[JList[T]] = + new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + + /** + * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of + * elements (a, b) where a is in `this` and b is in `other`. + */ + def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = + JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classTag))(classTag, + other.classTag) + + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ + def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { + implicit val kcm: ClassTag[K] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[JList[T]] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]] + JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm) + } + + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ + def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { + implicit val kcm: ClassTag[K] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[JList[T]] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]] + JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm) + } + + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: String): JavaRDD[String] = rdd.pipe(command) + + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: JList[String]): JavaRDD[String] = + rdd.pipe(asScalaBuffer(command)) + + /** + * Return an RDD created by piping elements to a forked external process. 
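+ * An illustrative sketch (the RDD name and the shell command are assumed):
+ * {{{
+ *   val env = new java.util.HashMap[String, String]()
+ *   env.put("LC_ALL", "C")
+ *   val piped = lines.pipe(java.util.Arrays.asList("sort", "-u"), env)
+ * }}}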
+ */ + def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = + rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + + /** + * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, + * second element in each RDD, etc. Assumes that the two RDDs have the *same number of + * partitions* and the *same number of elements in each partition* (e.g. one was made through + * a map on the other). + */ + def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = { + JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag) + } + + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. + */ + def zipPartitions[U, V]( + other: JavaRDDLike[U, _], + f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { + def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) + JavaRDD.fromRDD( + rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType()) + } + + // Actions (launch a job to return a value to the user program) + + /** + * Applies a function f to all elements of this RDD. + */ + def foreach(f: VoidFunction[T]) { + val cleanF = rdd.context.clean(f) + rdd.foreach(cleanF) + } + + /** + * Return an array that contains all of the elements in this RDD. + */ + def collect(): JList[T] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[T] = rdd.collect().toSeq + new java.util.ArrayList(arr) + } + + /** + * Reduces the elements of this RDD using the specified commutative and associative binary operator. + */ + def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) + + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using a + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * modify t1 and return it as its result value to avoid object allocation; however, it should not + * modify t2. + */ + def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = + rdd.fold(zeroValue)(f) + + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using + * given combine functions and a neutral "zero value". This function can return a different result + * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U + * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are + * allowed to modify and return their first argument instead of creating a new U to avoid memory + * allocation. + */ + def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U]): U = + rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType) + + /** + * Return the number of elements in the RDD. + */ + def count(): Long = rdd.count() + + /** + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. 
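+ * An illustrative sketch (timeout and confidence values are arbitrary):
+ * {{{
+ *   val estimate = rdd.countApprox(5000, 0.90)   // wait at most 5 seconds for a 90% bound
+ *   println(estimate.initialValue)               // a BoundedDouble with mean, low and high
+ * }}}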
+ */ + def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = + rdd.countApprox(timeout, confidence) + + /** + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. + */ + def countApprox(timeout: Long): PartialResult[BoundedDouble] = + rdd.countApprox(timeout) + + /** + * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final + * combine step happens locally on the master, equivalent to running a single reduce task. + */ + def countByValue(): java.util.Map[T, java.lang.Long] = + mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + + /** + * (Experimental) Approximate version of countByValue(). + */ + def countByValueApprox( + timeout: Long, + confidence: Double + ): PartialResult[java.util.Map[T, BoundedDouble]] = + rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + + /** + * (Experimental) Approximate version of countByValue(). + */ + def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = + rdd.countByValueApprox(timeout).map(mapAsJavaMap) + + /** + * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so + * it will be slow if a lot of partitions are required. In that case, use collect() to get the + * whole RDD instead. + */ + def take(num: Int): JList[T] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[T] = rdd.take(num).toSeq + new java.util.ArrayList(arr) + } + + def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq + new java.util.ArrayList(arr) + } + + /** + * Return the first element in this RDD. + */ + def first(): T = rdd.first() + + /** + * Save this RDD as a text file, using string representations of elements. + */ + def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) + + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) = + rdd.saveAsTextFile(path, codec) + + /** + * Save this RDD as a SequenceFile of serialized objects. + */ + def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + JavaPairRDD.fromRDD(rdd.keyBy(f)) + } + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. + */ + def checkpoint() = rdd.checkpoint() + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed: Boolean = rdd.isCheckpointed + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Optional[String] = { + JavaUtils.optionToOptional(rdd.getCheckpointFile) + } + + /** A description of this RDD and its recursive dependencies for debugging. 
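+ * Typically printed while debugging a lineage, e.g. {{{ println(rdd.toDebugString()) }}}.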
*/ + def toDebugString(): String = { + rdd.toDebugString + } + + /** + * Returns the top K elements from this RDD as defined by + * the specified Comparator[T]. + * @param num the number of top elements to return + * @param comp the comparator that defines the order + * @return an array of top elements + */ + def top(num: Int, comp: Comparator[T]): JList[T] = { + import scala.collection.JavaConversions._ + val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) + val arr: java.util.Collection[T] = topElems.toSeq + new java.util.ArrayList(arr) + } + + /** + * Returns the top K elements from this RDD using the + * natural ordering for T. + * @param num the number of top elements to return + * @return an array of top elements + */ + def top(num: Int): JList[T] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] + top(num, comp) + } + + /** + * Returns the first K elements from this RDD as defined by + * the specified Comparator[T] and maintains the order. + * @param num the number of top elements to return + * @param comp the comparator that defines the order + * @return an array of top elements + */ + def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { + import scala.collection.JavaConversions._ + val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) + val arr: java.util.Collection[T] = topElems.toSeq + new java.util.ArrayList(arr) + } + + /** + * Returns the first K elements from this RDD using the + * natural ordering for T while maintain the order. + * @param num the number of top elements to return + * @return an array of top elements + */ + def takeOrdered(num: Int): JList[T] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] + takeOrdered(num, comp) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala new file mode 100644 index 0000000000..d5b77357a2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java + +import java.util.{Map => JMap} + +import scala.collection.JavaConversions +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import com.google.common.base.Optional + +import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, SparkContext} +import org.apache.spark.SparkContext.IntAccumulatorParam +import org.apache.spark.SparkContext.DoubleAccumulatorParam +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD + +/** + * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and + * works with Java collections instead of Scala ones. + */ +class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + */ + def this(master: String, appName: String) = this(new SparkContext(master, appName)) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jarFile JAR file to send to the cluster. This can be a path on the local file system + * or an HDFS, HTTP, HTTPS, or FTP URL. + */ + def this(master: String, appName: String, sparkHome: String, jarFile: String) = + this(new SparkContext(master, appName, sparkHome, Seq(jarFile))) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + */ + def this(master: String, appName: String, sparkHome: String, jars: Array[String]) = + this(new SparkContext(master, appName, sparkHome, jars.toSeq)) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + * @param environment Environment variables to set on worker nodes + */ + def this(master: String, appName: String, sparkHome: String, jars: Array[String], + environment: JMap[String, String]) = + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment)) + + private[spark] val env = sc.env + + /** Distribute a local Scala collection to form an RDD. */ + def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) + } + + /** Distribute a local Scala collection to form an RDD. 
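The constructor overloads mirror the Scala SparkContext. A hedged sketch of a cluster-mode construction; the master URL, application name, Spark home and jar path are illustrative only.

    import java.util.Arrays
    import org.apache.spark.api.java.JavaSparkContext

    // Illustrative only: connect to a standalone master and ship one application jar.
    val jsc = new JavaSparkContext("spark://master:7077", "MyApp",
      "/opt/spark", Array("hdfs:///libs/my-app.jar"))
    val nums = jsc.parallelize(Arrays.asList(1, 2, 3), 2)   // two slices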
*/ + def parallelize[T](list: java.util.List[T]): JavaRDD[T] = + parallelize(list, sc.defaultParallelism) + + /** Distribute a local Scala collection to form an RDD. */ + def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int) + : JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[V] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) + } + + /** Distribute a local Scala collection to form an RDD. */ + def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] = + parallelizePairs(list, sc.defaultParallelism) + + /** Distribute a local Scala collection to form an RDD. */ + def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = + JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), + numSlices)) + + /** Distribute a local Scala collection to form an RDD. */ + def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = + parallelizeDoubles(list, sc.defaultParallelism) + + /** + * Read a text file from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI, and return it as an RDD of Strings. + */ + def textFile(path: String): JavaRDD[String] = sc.textFile(path) + + /** + * Read a text file from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI, and return it as an RDD of Strings. + */ + def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits) + + /**Get an RDD for a Hadoop SequenceFile with given key and value types. */ + def sequenceFile[K, V](path: String, + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int + ): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) + new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits)) + } + + /**Get an RDD for a Hadoop SequenceFile. */ + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): + JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) + new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass)) + } + + /** + * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and + * BytesWritable values that contain a serialized partition. This is still an experimental storage + * format and may not be supported exactly as is in future Spark releases. It will also be pretty + * slow if you use the default serializer (Java serialization), though the nice thing about it is + * that there's very little effort required to save arbitrary objects. + */ + def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + sc.objectFile(path, minSplits)(cm) + } + + /** + * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and + * BytesWritable values that contain a serialized partition. This is still an experimental storage + * format and may not be supported exactly as is in future Spark releases. 
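A hedged sketch of the input methods above; all paths are placeholders and the key/value classes are the standard Hadoop Writables.

    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local", "inputSketch")
    val lines = jsc.textFile("hdfs:///data/input.txt", 4)             // 4 minimum splits
    val pairs = jsc.sequenceFile("hdfs:///data/seq", classOf[Text], classOf[IntWritable])
    val restored = jsc.objectFile[String]("hdfs:///data/objects", 8)  // written earlier with saveAsObjectFile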
It will also be pretty + * slow if you use the default serializer (Java serialization), though the nice thing about it is + * that there's very little effort required to save arbitrary objects. + */ + def objectFile[T](path: String): JavaRDD[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + sc.objectFile(path)(cm) + } + + /** + * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any + * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, + * etc). + */ + def hadoopRDD[K, V, F <: InputFormat[K, V]]( + conf: JobConf, + inputFormatClass: Class[F], + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int + ): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) + new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits)) + } + + /** + * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any + * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, + * etc). + */ + def hadoopRDD[K, V, F <: InputFormat[K, V]]( + conf: JobConf, + inputFormatClass: Class[F], + keyClass: Class[K], + valueClass: Class[V] + ): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) + new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) + } + + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ + def hadoopFile[K, V, F <: InputFormat[K, V]]( + path: String, + inputFormatClass: Class[F], + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int + ): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) + new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)) + } + + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ + def hadoopFile[K, V, F <: InputFormat[K, V]]( + path: String, + inputFormatClass: Class[F], + keyClass: Class[K], + valueClass: Class[V] + ): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) + new JavaPairRDD(sc.hadoopFile(path, + inputFormatClass, keyClass, valueClass)) + } + + /** + * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat + * and extra configuration options to pass to the input format. + */ + def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( + path: String, + fClass: Class[F], + kClass: Class[K], + vClass: Class[V], + conf: Configuration): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(kClass) + implicit val vcm: ClassTag[V] = ClassTag(vClass) + new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) + } + + /** + * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat + * and extra configuration options to pass to the input format. + */ + def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( + conf: Configuration, + fClass: Class[F], + kClass: Class[K], + vClass: Class[V]): JavaPairRDD[K, V] = { + implicit val kcm: ClassTag[K] = ClassTag(kClass) + implicit val vcm: ClassTag[V] = ClassTag(vClass) + new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) + } + + /** Build the union of two or more RDDs. 
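Old-API and new-API Hadoop input formats are both supported. A hedged sketch reading the same text data through each; the paths are placeholders and the input format classes are the standard Hadoop ones.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat
    import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local", "hadoopSketch")
    val oldApi = jsc.hadoopFile("hdfs:///data/logs", classOf[TextInputFormat],
      classOf[LongWritable], classOf[Text])
    val newApi = jsc.newAPIHadoopFile("hdfs:///data/logs", classOf[NewTextInputFormat],
      classOf[LongWritable], classOf[Text], new Configuration())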
*/ + override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { + val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + implicit val cm: ClassTag[T] = first.classTag + sc.union(rdds)(cm) + } + + /** Build the union of two or more RDDs. */ + override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) + : JavaPairRDD[K, V] = { + val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + implicit val cm: ClassTag[(K, V)] = first.classTag + implicit val kcm: ClassTag[K] = first.kClassTag + implicit val vcm: ClassTag[V] = first.vClassTag + new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm) + } + + /** Build the union of two or more RDDs. */ + override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { + val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) + new JavaDoubleRDD(sc.union(rdds)) + } + + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] + + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Double): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue) + + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = + sc.accumulator(initialValue)(accumulatorParam) + + /** + * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + + /** + * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.Broadcast]] object for + * reading it in distributed functions. The variable will be sent to each cluster only once. + */ + def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value) + + /** Shut down the SparkContext. */ + def stop() { + sc.stop() + } + + /** + * Get Spark's home location from either a value set through the constructor, + * or the spark.home Java property, or the SPARK_HOME environment variable + * (in that order of preference). If neither of these is set, return None. 
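Shared variables follow the Scala API: accumulators are added to from tasks and read only on the master, while broadcasts are read-only everywhere. A hedged sketch; the task-side calls are indicated only in comments since no job is run here.

    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local", "sharedVarsSketch")
    val counter = jsc.intAccumulator(0)   // Accumulator[java.lang.Integer]
    val factor = jsc.broadcast(10)        // sent to each node only once
    // Inside a task one would call counter.add(1) and read factor.value;
    // on the master, counter.value returns the merged total.
    println(counter.value)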
+ */ + def getSparkHome(): Optional[String] = JavaUtils.optionToOptional(sc.getSparkHome()) + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. + */ + def addFile(path: String) { + sc.addFile(path) + } + + /** + * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addJar(path: String) { + sc.addJar(path) + } + + /** + * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to + * any new nodes. + */ + def clearJars() { + sc.clearJars() + } + + /** + * Clear the job's list of files added by `addFile` so that they do not get downloaded to + * any new nodes. + */ + def clearFiles() { + sc.clearFiles() + } + + /** + * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + */ + def hadoopConfiguration(): Configuration = { + sc.hadoopConfiguration + } + + /** + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. + */ + def setCheckpointDir(dir: String, useExisting: Boolean) { + sc.setCheckpointDir(dir, useExisting) + } + + /** + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. + */ + def setCheckpointDir(dir: String) { + sc.setCheckpointDir(dir) + } + + protected def checkpointFile[T](path: String): JavaRDD[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + new JavaRDD(sc.checkpointFile(path)) + } +} + +object JavaSparkContext { + implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc) + + implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java new file mode 100644 index 0000000000..c9cbce5624 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
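A hedged sketch of shipping side files and jars, and of the implicit conversions defined on the companion object. The paths are placeholders, and SparkFiles.get is the lookup referenced by the addFile scaladoc above rather than something defined in this hunk.

    import org.apache.spark.api.java.JavaSparkContext
    import org.apache.spark.api.java.JavaSparkContext._   // brings the implicit conversions into scope

    val jsc = new JavaSparkContext("local", "filesSketch")
    jsc.addFile("hdfs:///config/lookup.txt")   // tasks would resolve it via SparkFiles.get("lookup.txt")
    jsc.addJar("hdfs:///libs/deps.jar")
    val sc: org.apache.spark.SparkContext = jsc   // via the implicit toSparkContext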
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; + +// See +// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html +abstract class JavaSparkContextVarargsWorkaround { + public <T> JavaRDD<T> union(JavaRDD<T>... rdds) { + if (rdds.length == 0) { + throw new IllegalArgumentException("Union called on empty list"); + } + ArrayList<JavaRDD<T>> rest = new ArrayList<JavaRDD<T>>(rdds.length - 1); + for (int i = 1; i < rdds.length; i++) { + rest.add(rdds[i]); + } + return union(rdds[0], rest); + } + + public JavaDoubleRDD union(JavaDoubleRDD... rdds) { + if (rdds.length == 0) { + throw new IllegalArgumentException("Union called on empty list"); + } + ArrayList<JavaDoubleRDD> rest = new ArrayList<JavaDoubleRDD>(rdds.length - 1); + for (int i = 1; i < rdds.length; i++) { + rest.add(rdds[i]); + } + return union(rdds[0], rest); + } + + public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) { + if (rdds.length == 0) { + throw new IllegalArgumentException("Union called on empty list"); + } + ArrayList<JavaPairRDD<K, V>> rest = new ArrayList<JavaPairRDD<K, V>>(rdds.length - 1); + for (int i = 1; i < rdds.length; i++) { + rest.add(rdds[i]); + } + return union(rdds[0], rest); + } + + // These methods take separate "first" and "rest" elements to avoid having the same type erasure + abstract public <T> JavaRDD<T> union(JavaRDD<T> first, List<JavaRDD<T>> rest); + abstract public JavaDoubleRDD union(JavaDoubleRDD first, List<JavaDoubleRDD> rest); + abstract public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>> rest); +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala new file mode 100644 index 0000000000..ecbf18849a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
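The varargs overloads above exist only so that Java code can write jsc.union(a, b, c); each one forwards to the matching union(first, rest) implemented in JavaSparkContext. An illustrative sketch with placeholder data:

    import java.util.Arrays
    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local", "unionSketch")
    val a = jsc.parallelize(Arrays.asList(1, 2))
    val b = jsc.parallelize(Arrays.asList(3))
    val c = jsc.parallelize(Arrays.asList(4, 5))
    val all = jsc.union(a, b, c)   // one JavaRDD with the elements 1, 2, 3, 4, 5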
+ */ + +package org.apache.spark.api.java + +import com.google.common.base.Optional + +object JavaUtils { + def optionToOptional[T](option: Option[T]): Optional[T] = + option match { + case Some(value) => Optional.of(value) + case None => Optional.absent() + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/StorageLevels.java b/core/src/main/scala/org/apache/spark/api/java/StorageLevels.java new file mode 100644 index 0000000000..0744269773 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/StorageLevels.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import org.apache.spark.storage.StorageLevel; + +/** + * Expose some commonly useful storage level constants. + */ +public class StorageLevels { + public static final StorageLevel NONE = new StorageLevel(false, false, false, 1); + public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1); + public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2); + public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1); + public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2); + public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1); + public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2); + public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1); + public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); + public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); + public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { + return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java new file mode 100644 index 0000000000..4830067f7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
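StorageLevels simply re-exposes the Scala StorageLevel constants in a Java-friendly place. A hedged sketch; persist() on the Java RDD classes is assumed from the wider API and is not part of this hunk.

    import org.apache.spark.api.java.StorageLevels

    val level = StorageLevels.MEMORY_AND_DISK_SER_2          // disk + memory, serialized, 2 replicas
    val custom = StorageLevels.create(true, true, false, 3)  // the same, but with 3 replicas
    // e.g. someJavaRdd.persist(level), assuming persist() as elsewhere in the Java API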
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + + +import scala.runtime.AbstractFunction1; + +import java.io.Serializable; + +/** + * A function that returns zero or more records of type Double from each input record. + */ +// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is +// overloaded for both FlatMapFunction and DoubleFlatMapFunction. +public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>> + implements Serializable { + + public abstract Iterable<Double> call(T t); + + @Override + public final Iterable<Double> apply(T t) { return call(t); } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java new file mode 100644 index 0000000000..db34cd190a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + + +import scala.runtime.AbstractFunction1; + +import java.io.Serializable; + +/** + * A function that returns Doubles, and can be used to construct DoubleRDDs. + */ +// DoubleFunction does not extend Function because some UDF functions, like map, +// are overloaded for both Function and DoubleFunction. +public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double> + implements Serializable { + + public abstract Double call(T t) throws Exception; +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala new file mode 100644 index 0000000000..b7c0d78e33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
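DoubleFunction exists purely so the overloaded map() can produce a JavaDoubleRDD. A hedged sketch of implementing it, written here in Scala; Java callers would use an anonymous inner class, and map() itself is assumed from the wider Java API.

    import org.apache.spark.api.java.function.DoubleFunction

    val strLen = new DoubleFunction[String] {
      override def call(s: String): java.lang.Double = java.lang.Double.valueOf(s.length.toDouble)
    }
    // Per the comment above, javaRdd.map(strLen) resolves to the DoubleFunction overload
    // and yields a JavaDoubleRDD.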
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +import scala.reflect.ClassTag + +/** + * A function that returns zero or more output records from each input record. + */ +abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { + @throws(classOf[Exception]) + def call(x: T) : java.lang.Iterable[R] + + def elementType() : ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]] +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala new file mode 100644 index 0000000000..7a505df4be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +import scala.reflect.ClassTag + +/** + * A function that takes two inputs and returns zero or more output records. + */ +abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { + @throws(classOf[Exception]) + def call(a: A, b:B) : java.lang.Iterable[C] + + def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]] +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java new file mode 100644 index 0000000000..f9dae6ed34 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
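A hedged sketch of a FlatMapFunction that splits lines into words; flatMap() on the Java RDD classes is assumed from the wider API and the data is illustrative.

    import java.util.Arrays
    import org.apache.spark.api.java.function.FlatMapFunction

    val splitWords = new FlatMapFunction[String, String] {
      override def call(line: String): java.lang.Iterable[String] =
        Arrays.asList(line.split(" "): _*)
    }
    // e.g. javaLines.flatMap(splitWords) would yield a JavaRDD[String] of words.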
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; +import scala.runtime.AbstractFunction1; + +import java.io.Serializable; + + +/** + * Base class for functions whose return types do not create special RDDs. PairFunction and + * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed + * when mapping RDDs of other types. + */ +public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable { + public abstract R call(T t) throws Exception; + + public ClassTag<R> returnType() { + return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class); + } +} + diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java new file mode 100644 index 0000000000..1659bfc552 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; +import scala.runtime.AbstractFunction2; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 and returns an R. + */ +public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R> + implements Serializable { + + public abstract R call(T1 t1, T2 t2) throws Exception; + + public ClassTag<R> returnType() { + return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class); + } +} + diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java new file mode 100644 index 0000000000..5a5c9b6296 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
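Function and Function2 are the general-purpose one- and two-argument shapes. A hedged sketch of each, written in Scala for brevity; map() and reduce() on the Java RDD classes are assumed from the wider API.

    import org.apache.spark.api.java.function.{Function, Function2}

    val plusOne = new Function[java.lang.Integer, java.lang.Integer] {
      override def call(x: java.lang.Integer): java.lang.Integer =
        java.lang.Integer.valueOf(x.intValue() + 1)
    }
    val sum = new Function2[java.lang.Integer, java.lang.Integer, java.lang.Integer] {
      override def call(a: java.lang.Integer, b: java.lang.Integer): java.lang.Integer =
        java.lang.Integer.valueOf(a.intValue() + b.intValue())
    }
    // e.g. javaNums.map(plusOne) and javaNums.reduce(sum)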
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import scala.Tuple2; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; +import scala.runtime.AbstractFunction1; + +import java.io.Serializable; + +/** + * A function that returns zero or more key-value pair records from each input record. The + * key-value pairs are represented as scala.Tuple2 objects. + */ +// PairFlatMapFunction does not extend FlatMapFunction because flatMap is +// overloaded for both FlatMapFunction and PairFlatMapFunction. +public abstract class PairFlatMapFunction<T, K, V> + extends WrappedFunction1<T, Iterable<Tuple2<K, V>>> + implements Serializable { + + public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception; + + public ClassTag<K> keyType() { + return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class); + } + + public ClassTag<V> valueType() { + return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class); + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java new file mode 100644 index 0000000000..4c39f483e5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import scala.Tuple2; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; +import scala.runtime.AbstractFunction1; + +import java.io.Serializable; + +/** + * A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs. + */ +// PairFunction does not extend Function because some UDF functions, like map, +// are overloaded for both Function and PairFunction. 
+public abstract class PairFunction<T, K, V> + extends WrappedFunction1<T, Tuple2<K, V>> + implements Serializable { + + public abstract Tuple2<K, V> call(T t) throws Exception; + + public ClassTag<K> keyType() { + return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class); + } + + public ClassTag<V> valueType() { + return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class); + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala new file mode 100644 index 0000000000..ea94313a4a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +/** + * A function with no return value. + */ +// This allows Java users to write void methods without having to return Unit. +abstract class VoidFunction[T] extends Serializable { + @throws(classOf[Exception]) + def call(t: T) : Unit +} + +// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly +// return Unit), so it is implicitly converted to a Function1[T, Unit]: +object VoidFunction { + implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x)) +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala new file mode 100644 index 0000000000..cfe694f65d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +import scala.runtime.AbstractFunction1 + +/** + * Subclass of Function1 for ease of calling from Java. 
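PairFunction and VoidFunction complete the set: the first lets the overloaded map() build a JavaPairRDD, the second gives Java callers a side-effect-only shape. A hedged sketch; foreach() is assumed from the wider Java API.

    import org.apache.spark.api.java.function.{PairFunction, VoidFunction}

    val toPair = new PairFunction[String, String, java.lang.Integer] {
      override def call(s: String): scala.Tuple2[String, java.lang.Integer] =
        new scala.Tuple2(s, java.lang.Integer.valueOf(s.length))
    }
    val printer = new VoidFunction[String] {
      override def call(s: String): Unit = println(s)
    }
    // e.g. javaWords.map(toPair) yields a JavaPairRDD[String, java.lang.Integer],
    // and javaWords.foreach(printer) prints each element.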
The main thing it does is re-expose the + * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply + * isn't marked to allow that). + */ +private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { + @throws(classOf[Exception]) + def call(t: T): R + + final def apply(t: T): R = call(t) +} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala new file mode 100644 index 0000000000..eb9277c6fb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +import scala.runtime.AbstractFunction2 + +/** + * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the + * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply + * isn't marked to allow that). + */ +private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { + @throws(classOf[Exception]) + def call(t1: T1, t2: T2): R + + final def apply(t1: T1, t2: T2): R = call(t1, t2) +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala new file mode 100644 index 0000000000..b090c6edf3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import org.apache.spark.Partitioner +import java.util.Arrays +import org.apache.spark.util.Utils + +/** + * A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. 
+ * + * Stores the unique id() of the Python-side partitioning function so that it is incorporated into + * equality comparisons. Correctness requires that the id is a unique identifier for the + * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning + * function). This can be ensured by using the Python id() function and maintaining a reference + * to the Python partitioning function so that its id() is not reused. + */ +private[spark] class PythonPartitioner( + override val numPartitions: Int, + val pyPartitionFunctionId: Long) + extends Partitioner { + + override def getPartition(key: Any): Int = key match { + case null => 0 + case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions) + case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions) + } + + override def equals(other: Any): Boolean = other match { + case h: PythonPartitioner => + h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId + case _ => + false + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala new file mode 100644 index 0000000000..cb2db77f39 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io._ +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.PipedRDD +import org.apache.spark.util.Utils + +private[spark] class PythonRDD[T: ClassTag]( + parent: RDD[T], + command: Seq[String], + envVars: JMap[String, String], + pythonIncludes: JList[String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) + extends RDD[Array[Byte]](parent) { + + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. 
by spaces) + def this(parent: RDD[T], command: String, envVars: JMap[String, String], + pythonIncludes: JList[String], + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, + broadcastVars, accumulator) + + override def getPartitions = parent.partitions + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val worker = env.createPythonWorker(pythonExec, envVars.toMap) + + // Start a thread to feed the process input from our parent's iterator + new Thread("stdin writer for " + pythonExec) { + override def run() { + try { + SparkEnv.set(env) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + val printOut = new PrintWriter(stream) + // Partition index + dataOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + // Broadcast variables + dataOut.writeInt(broadcastVars.length) + for (broadcast <- broadcastVars) { + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + } + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.length) + for (f <- pythonIncludes) { + PythonRDD.writeAsPickle(f, dataOut) + } + dataOut.flush() + // Serialized user code + for (elem <- command) { + printOut.println(elem) + } + printOut.flush() + // Data values + for (elem <- parent.iterator(split, context)) { + PythonRDD.writeAsPickle(elem, dataOut) + } + dataOut.flush() + printOut.flush() + worker.shutdownOutput() + } catch { + case e: IOException => + // This can happen for legitimate reasons if the Python code stops returning data before we are done + // passing elements through, e.g., for take(). Just log a message to say it happened. 
+ logInfo("stdin writer to Python finished early") + logDebug("stdin writer to Python finished early", e) + } + } + }.start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + return new Iterator[Array[Byte]] { + def next(): Array[Byte] = { + val obj = _nextObj + if (hasNext) { + // FIXME: can deadlock if worker is waiting for us to + // respond to current message (currently irrelevant because + // output is shutdown before we read any input) + _nextObj = read() + } + obj + } + + private def read(): Array[Byte] = { + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case -3 => + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + read + case -2 => + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj)) + case -1 => + // We've finished the data section of the output, but we can still + // read some accumulator updates; let's do that, breaking when we + // get a negative length record. + var len2 = stream.readInt() + while (len2 >= 0) { + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + len2 = stream.readInt() + } + new Array[Byte](0) + } + } catch { + case eof: EOFException => { + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + case e : Throwable => throw e + } + } + + var _nextObj = read() + + def hasNext = _nextObj.length != 0 + } + } + + val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** Thrown for exceptions in user Python code. */ +private class PythonException(msg: String) extends Exception(msg) + +/** + * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. + * This is used by PySpark's shuffle operations. + */ +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends + RDD[(Array[Byte], Array[Byte])](prev) { + override def getPartitions = prev.partitions + override def compute(split: Partition, context: TaskContext) = + prev.iterator(split, context).grouped(2).map { + case Seq(a, b) => (a, b) + case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) + } + val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +private[spark] object PythonRDD { + + /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ + def stripPickle(arr: Array[Byte]) : Array[Byte] = { + arr.slice(2, arr.length - 1) + } + + /** + * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. + * The data format is a 32-bit integer representing the pickled object's length (in bytes), + * followed by the pickled data. 
+ * + * Pickle module: + * + * http://docs.python.org/2/library/pickle.html + * + * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: + * + * http://hg.python.org/cpython/file/2.6/Lib/pickle.py + * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py + * + * @param elem the object to write + * @param dOut a data output stream + */ + def writeAsPickle(elem: Any, dOut: DataOutputStream) { + if (elem.isInstanceOf[Array[Byte]]) { + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] + val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t._1)) + dOut.write(PythonRDD.stripPickle(t._2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } else { + throw new SparkException("Unexpected RDD type") + } + } + + def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + JavaRDD[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} + case e : Throwable => throw e + } + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } + + def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + import scala.collection.JavaConverters._ + writeIteratorToPickleFile(items.asScala, filename) + } + + def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + val file = new DataOutputStream(new FileOutputStream(filename)) + for (item <- items) { + writeAsPickle(item, file) + } + file.close() + } + + def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { + implicit val cm : ClassTag[T] = rdd.elementClassTag + rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator + } +} + +private object Pickle { + val PROTO: Byte = 0x80.toByte + val TWO: Byte = 0x02.toByte + val BINUNICODE: Byte = 'X' + val STOP: Byte = '.' + val TUPLE2: Byte = 0x86.toByte + val EMPTY_LIST: Byte = ']' + val MARK: Byte = '(' + val APPENDS: Byte = 'e' +} + +private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { + override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") +} + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. 
+ */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + Utils.checkHost(serverHost, "Expected hostname") + + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala new file mode 100644 index 0000000000..67d45723ba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.{OutputStreamWriter, File, DataInputStream, IOException} +import java.net.{ServerSocket, Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import org.apache.spark._ + +private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends Logging { + + // Because forking processes from Java is expensive, we prefer to launch a single Python daemon + // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently + // only works on UNIX-based systems now because it uses signals for child management, so we can + // also fall back to launching workers (pyspark/worker.py) directly. + val useDaemon = !System.getProperty("os.name").startsWith("Windows") + + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + if (useDaemon) { + createThroughDaemon() + } else { + createSimpleWorker() + } + } + + /** + * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself + * to avoid the high cost of forking from Java. 
This currently only works on UNIX-based systems. + */ + private def createThroughDaemon(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon() + + // Attempt to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon() + startDaemon() + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + /** + * Launch a worker by executing worker.py directly and telling it to connect to us. + */ + private def createSimpleWorker(): Socket = { + var serverSocket: ServerSocket = null + try { + serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) + + // Create and start the worker + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH") + workerEnv.put("PYTHONPATH", pythonPath) + val worker = pb.start() + + // Redirect the worker's stderr to ours + new Thread("stderr reader for " + pythonExec) { + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. + val in = worker.getErrorStream + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + + // Redirect worker's stdout to our stderr + new Thread("stdout reader for " + pythonExec) { + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. + val in = worker.getInputStream + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + + // Tell the worker our port + val out = new OutputStreamWriter(worker.getOutputStream) + out.write(serverSocket.getLocalPort + "\n") + out.flush() + + // Wait for it to connect to our socket + serverSocket.setSoTimeout(10000) + try { + return serverSocket.accept() + } catch { + case e: Exception => + throw new SparkException("Python worker did not connect back in time", e) + } + } finally { + if (serverSocket != null) { + serverSocket.close() + } + } + null + } + + def stop() { + stopDaemon() + } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH") + workerEnv.put("PYTHONPATH", pythonPath) + daemon = pb.start() + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. 
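[Editor's note] Taken together, create() and createSimpleWorker() above hand the executor one connected socket per Python task, while the daemon path amortizes process start-up on UNIX. A minimal usage sketch; the call site shown here is hypothetical, since in Spark the factory is driven from the Python RDD compute path:

    // Hypothetical call site, only to illustrate the factory's contract.
    val envVars = Map[String, String]()              // extra environment for the worker
    val factory = new PythonWorkerFactory("python", envVars)

    val worker: java.net.Socket = factory.create()   // daemon fork on UNIX, worker.py elsewhere
    try {
      // ... stream the task's data to and from `worker` here ...
    } finally {
      worker.close()
      factory.stop()                                 // tears the daemon down, if one was started
    }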
+ val in = daemon.getErrorStream + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + + val in = new DataInputStream(daemon.getInputStream) + daemonPort = in.readInt() + + // Redirect further stdout output to our stderr + new Thread("stdout reader for " + pythonExec) { + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + } catch { + case e => { + stopDaemon() + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. + } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } + + daemon = null + daemonPort = 0 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala new file mode 100644 index 0000000000..93e7815ab5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala @@ -0,0 +1,1058 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.broadcast + +import java.io._ +import java.net._ +import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.{ListBuffer, Map, Set} +import scala.math + +import org.apache.spark._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) + extends Broadcast[T](id) + with Logging + with Serializable { + + def value = value_ + + def blockId: String = "broadcast_" + id + + MultiTracker.synchronized { + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var hasBlocksBitVector: BitSet = null + @transient var numCopiesSent: Array[Int] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = new AtomicInteger(0) + + // Used ONLY by driver to track how many unique blocks have been sent out + @transient var sentBlocks = new AtomicInteger(0) + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + + @transient var listOfSources = ListBuffer[SourceInfo]() + + @transient var serveMR: ServeMultipleRequests = null + + // Used only in driver + @transient var guideMR: GuideMultipleRequests = null + + // Used only in Workers + @transient var ttGuide: TalkToGuide = null + + @transient var hostAddress = Utils.localIpAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast() + } + + def sendBroadcast() { + logInfo("Local host address: " + hostAddress) + + // Create a variableInfo object and store it in valueInfos + var variableInfo = MultiTracker.blockifyObject(value_) + + // Prepare the value being broadcasted + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks.set(variableInfo.totalBlocks) + + // Guide has all the blocks + hasBlocksBitVector = new BitSet(totalBlocks) + hasBlocksBitVector.set(0, totalBlocks) + + // Guide still hasn't sent any block + numCopiesSent = new Array[Int](totalBlocks) + + guideMR = new GuideMultipleRequests + guideMR.setDaemon(true) + guideMR.start() + logInfo("GuideMultipleRequests started...") + + // Must always come AFTER guideMR is created + while (guidePort == -1) { + guidePortLock.synchronized { guidePortLock.wait() } + } + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + // Must always come AFTER serveMR is created + while (listenPort == -1) { + listenPortLock.synchronized { listenPortLock.wait() } + } + + // Must always come AFTER listenPort is created + val driverSource = + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) + hasBlocksBitVector.synchronized { + driverSource.hasBlocksBitVector = hasBlocksBitVector + } + + // In the beginning, this is the only known source to Guide + listOfSources += driverSource + + // Register with the Tracker + MultiTracker.registerBroadcast(id, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) + } + + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + MultiTracker.synchronized { 
+ SparkEnv.get.blockManager.getSingle(blockId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + logInfo("Started reading broadcast variable " + id) + // Initializing everything because driver will only send null/0 values + // Only the 1st worker in a node can be here. Others will get from cache + initializeWorkerVariables() + + logInfo("Local host address: " + hostAddress) + + // Start local ServeMultipleRequests thread first + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(id) + if (receptionSucceeded) { + value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } else { + logError("Reading broadcast variable " + id + " failed") + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + + // Initialize variables in the worker node. Driver sends everything as 0/null + private def initializeWorkerVariables() { + arrayOfBlocks = null + hasBlocksBitVector = null + numCopiesSent = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = new AtomicInteger(0) + + listenPortLock = new Object + totalBlocksLock = new Object + + serveMR = null + ttGuide = null + + hostAddress = Utils.localIpAddress + listenPort = -1 + + listOfSources = ListBuffer[SourceInfo]() + + stopBroadcast = false + } + + private def getLocalSourceInfo: SourceInfo = { + // Wait till hostName and listenPort are OK + while (listenPort == -1) { + listenPortLock.synchronized { listenPortLock.wait() } + } + + // Wait till totalBlocks and totalBytes are OK + while (totalBlocks == -1) { + totalBlocksLock.synchronized { totalBlocksLock.wait() } + } + + var localSourceInfo = SourceInfo( + hostAddress, listenPort, totalBlocks, totalBytes) + + localSourceInfo.hasBlocks = hasBlocks.get + + hasBlocksBitVector.synchronized { + localSourceInfo.hasBlocksBitVector = hasBlocksBitVector + } + + return localSourceInfo + } + + // Add new SourceInfo to the listOfSources. Update if it exists already. 
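[Editor's note] getLocalSourceInfo and the initialization code above lean on one recurring handshake: a field starts at -1, a background thread publishes the real value and calls notifyAll, and readers wait on a shared monitor. A condensed sketch of that handshake (the class checks the flag outside the monitor; keeping the check inside, as below, avoids a lost wake-up):

    // Sketch of the listenPort / guidePort / totalBlocks publication pattern.
    class PortHandshake {
      private val lock = new Object
      private var port: Int = -1                // -1 means "not yet published"

      def publish(p: Int): Unit = lock.synchronized {
        port = p
        lock.notifyAll()
      }

      def await(): Int = lock.synchronized {
        while (port == -1) lock.wait()
        port
      }
    }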
+ // Optimizing just by OR-ing the BitVectors was BAD for performance + private def addToListOfSources(newSourceInfo: SourceInfo) { + listOfSources.synchronized { + if (listOfSources.contains(newSourceInfo)) { + listOfSources = listOfSources - newSourceInfo + } + listOfSources += newSourceInfo + } + } + + private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { + newSourceInfos.foreach { newSourceInfo => + addToListOfSources(newSourceInfo) + } + } + + class TalkToGuide(gInfo: SourceInfo) + extends Thread with Logging { + override def run() { + + // Keep exchaning information until all blocks have been received + while (hasBlocks.get < totalBlocks) { + talkOnce + Thread.sleep(MultiTracker.ranGen.nextInt( + MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + + MultiTracker.MinKnockInterval) + } + + // Talk one more time to let the Guide know of reception completion + talkOnce + } + + // Connect to Guide and send this worker's information + private def talkOnce { + var clientSocketToGuide: Socket = null + var oosGuide: ObjectOutputStream = null + var oisGuide: ObjectInputStream = null + + clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) + oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) + oosGuide.flush() + oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) + + // Send local information + oosGuide.writeObject(getLocalSourceInfo) + oosGuide.flush() + + // Receive source information from Guide + var suitableSources = + oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] + logDebug("Received suitableSources from Driver " + suitableSources) + + addToListOfSources(suitableSources) + + oisGuide.close() + oosGuide.close() + clientSocketToGuide.close() + } + } + + def receiveBroadcast(variableID: Long): Boolean = { + val gInfo = MultiTracker.getGuideInfo(variableID) + + if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { listenPortLock.wait() } + } + + // Setup initial states of variables + totalBlocks = gInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) + hasBlocksBitVector = new BitSet(totalBlocks) + numCopiesSent = new Array[Int](totalBlocks) + totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } + totalBytes = gInfo.totalBytes + + // Start ttGuide to periodically talk to the Guide + var ttGuide = new TalkToGuide(gInfo) + ttGuide.setDaemon(true) + ttGuide.start() + logInfo("TalkToGuide started...") + + // Start pController to run TalkToPeer threads + var pcController = new PeerChatterController + pcController.setDaemon(true) + pcController.start() + logInfo("PeerChatterController started...") + + // FIXME: Must fix this. This might never break if broadcast fails. + // We should be able to break and send false. Also need to kill threads + while (hasBlocks.get < totalBlocks) { + Thread.sleep(MultiTracker.MaxKnockInterval) + } + + return true + } + + class PeerChatterController + extends Thread with Logging { + private var peersNowTalking = ListBuffer[SourceInfo]() + // TODO: There is a possible bug with blocksInRequestBitVector when a + // certain bit is NOT unset upon failure resulting in an infinite loop. 
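[Editor's note] The sleeps in TalkToGuide above, like the retry loop in MultiTracker.getGuideInfo later in this diff, draw a uniform back-off between MinKnockInterval and MaxKnockInterval; with the defaults declared in MultiTracker below these are 500 ms and 999 ms. The draw, spelled out:

    // Sketch: the randomized knock back-off, with MultiTracker's default settings.
    val minKnockMs = 500                        // spark.broadcast.minKnockInterval
    val maxKnockMs = 999                        // spark.broadcast.maxKnockInterval
    val rand = new java.util.Random()
    val sleepMs = rand.nextInt(maxKnockMs - minKnockMs) + minKnockMs  // uniform in [500, 999)
    Thread.sleep(sleepMs)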
+ private var blocksInRequestBitVector = new BitSet(totalBlocks) + + override def run() { + var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) + + while (hasBlocks.get < totalBlocks) { + var numThreadsToCreate = 0 + listOfSources.synchronized { + numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - + threadPool.getActiveCount + } + + while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { + var peerToTalkTo = pickPeerToTalkToRandom + + if (peerToTalkTo != null) + logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) + else + logDebug("No peer chosen...") + + if (peerToTalkTo != null) { + threadPool.execute(new TalkToPeer(peerToTalkTo)) + + // Add to peersNowTalking. Remove in the thread. We have to do this + // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once + peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before starting some more threads + Thread.sleep(MultiTracker.MinKnockInterval) + } + // Shutdown the thread pool + threadPool.shutdown() + } + + // Right now picking the one that has the most blocks this peer wants + // Also picking peer randomly if no one has anything interesting + private def pickPeerToTalkToRandom: SourceInfo = { + var curPeer: SourceInfo = null + var curMax = 0 + + logDebug("Picking peers to talk to...") + + // Find peers that are not connected right now + var peersNotInUse = ListBuffer[SourceInfo]() + listOfSources.synchronized { + peersNowTalking.synchronized { + peersNotInUse = listOfSources -- peersNowTalking + } + } + + // Select the peer that has the most blocks that this receiver does not + peersNotInUse.foreach { eachSource => + var tempHasBlocksBitVector: BitSet = null + hasBlocksBitVector.synchronized { + tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) + tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) + + if (tempHasBlocksBitVector.cardinality > curMax) { + curPeer = eachSource + curMax = tempHasBlocksBitVector.cardinality + } + } + + // Always picking randomly + if (curPeer == null && peersNotInUse.size > 0) { + // Pick uniformly the i'th required peer + var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) + + var peerIter = peersNotInUse.iterator + curPeer = peerIter.next + + while (i > 0) { + curPeer = peerIter.next + i = i - 1 + } + } + + return curPeer + } + + // Picking peer with the weight of rare blocks it has + private def pickPeerToTalkToRarestFirst: SourceInfo = { + // Find peers that are not connected right now + var peersNotInUse = ListBuffer[SourceInfo]() + listOfSources.synchronized { + peersNowTalking.synchronized { + peersNotInUse = listOfSources -- peersNowTalking + } + } + + // Count the number of copies of each block in the neighborhood + var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) + + listOfSources.synchronized { + listOfSources.foreach { eachSource => + for (i <- 0 until totalBlocks) { + numCopiesPerBlock(i) += + ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) + } + } + } + + // A block is considered rare if there are at most 2 copies of that block + // This CONSTANT could be a function of the neighborhood size + var rareBlocksIndices = ListBuffer[Int]() + for (i <- 0 until totalBlocks) { + if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { + rareBlocksIndices += i + } + } + + // Find peers 
with rare blocks + var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() + var totalRareBlocks = 0 + + peersNotInUse.foreach { eachPeer => + var hasRareBlocks = 0 + rareBlocksIndices.foreach { rareBlock => + if (eachPeer.hasBlocksBitVector.get(rareBlock)) { + hasRareBlocks += 1 + } + } + + if (hasRareBlocks > 0) { + peersWithRareBlocks += ((eachPeer, hasRareBlocks)) + } + totalRareBlocks += hasRareBlocks + } + + // Select a peer from peersWithRareBlocks based on weight calculated from + // unique rare blocks + var selectedPeerToTalkTo: SourceInfo = null + + if (peersWithRareBlocks.size > 0) { + // Sort the peers based on how many rare blocks they have + peersWithRareBlocks.sortBy(_._2) + + var randomNumber = MultiTracker.ranGen.nextDouble + var tempSum = 0.0 + + var i = 0 + do { + tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) + if (tempSum >= randomNumber) { + selectedPeerToTalkTo = peersWithRareBlocks(i)._1 + } + i += 1 + } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) + } + + if (selectedPeerToTalkTo == null) { + selectedPeerToTalkTo = pickPeerToTalkToRandom + } + + return selectedPeerToTalkTo + } + + class TalkToPeer(peerToTalkTo: SourceInfo) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + override def run() { + // TODO: There is a possible bug here regarding blocksInRequestBitVector + var blockToAskFor = -1 + + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run() { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) + + logInfo("TalkToPeer started... => " + peerToTalkTo) + + try { + // Connect to the source + peerSocketToSource = + new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + oisSource = + new ObjectInputStream(peerSocketToSource.getInputStream) + + // Receive latest SourceInfo from peerToTalkTo + var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] + // Update listOfSources + addToListOfSources(newPeerToTalkTo) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + // Send the latest SourceInfo + oosSource.writeObject(getLocalSourceInfo) + oosSource.flush() + + var keepReceiving = true + + while (hasBlocks.get < totalBlocks && keepReceiving) { + blockToAskFor = + pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) + + // No block to request + if (blockToAskFor < 0) { + // Nothing to receive from newPeerToTalkTo + keepReceiving = false + } else { + // Let other threads know that blockToAskFor is being requested + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set(blockToAskFor) + } + + // Start with sending the blockID + oosSource.writeObject(blockToAskFor) + oosSource.flush() + + // CHANGED: Driver might send some other block than the one + // requested to ensure fast spreading of all blocks. 
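[Editor's note] pickPeerToTalkToRarestFirst above samples a peer with probability proportional to how many rare blocks it can offer, by walking a cumulative sum against one uniform draw (the sortBy result there is discarded, so the walk effectively runs over the unsorted list). A standalone sketch of that roulette-wheel step with hypothetical generic names:

    // Sketch: weighted (roulette-wheel) selection over non-negative integer weights,
    // as used to pick a peer by its count of rare blocks.
    def pickWeighted[A](weighted: Seq[(A, Int)], rand: java.util.Random): Option[A] = {
      val total = weighted.map(_._2).sum.toDouble
      if (total <= 0) return None
      val target = rand.nextDouble() * total
      var cumulative = 0.0
      var chosen: Option[A] = None
      val it = weighted.iterator
      while (chosen.isEmpty && it.hasNext) {
        val (item, weight) = it.next()
        cumulative += weight
        if (cumulative >= target) chosen = Some(item)
      }
      chosen.orElse(weighted.lastOption.map(_._1))  // guards against floating-point shortfall
    }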
+ val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") + + if (!hasBlocksBitVector.get(bcBlock.blockID)) { + arrayOfBlocks(bcBlock.blockID) = bcBlock + + // Update the hasBlocksBitVector first + hasBlocksBitVector.synchronized { + hasBlocksBitVector.set(bcBlock.blockID) + hasBlocks.getAndIncrement + } + + // Some block(may NOT be blockToAskFor) has arrived. + // In any case, blockToAskFor is not in request any more + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set(blockToAskFor, false) + } + + // Reset blockToAskFor to -1. Else it will be considered missing + blockToAskFor = -1 + } + + // Send the latest SourceInfo + oosSource.writeObject(getLocalSourceInfo) + oosSource.flush() + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logError("TalktoPeer had a " + e) + // FIXME: Remove 'newPeerToTalkTo' from listOfSources + // We probably should have the following in some form, but not + // really here. This exception can happen if the sender just breaks connection + // listOfSources.synchronized { + // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) + // listOfSources = listOfSources - peerToTalkTo + // } + } + } finally { + // blockToAskFor != -1 => there was an exception + if (blockToAskFor != -1) { + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set(blockToAskFor, false) + } + } + + cleanUpConnections() + } + } + + // Right now it picks a block uniformly that this peer does not have + private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { + var needBlocksBitVector: BitSet = null + + // Blocks already present + hasBlocksBitVector.synchronized { + needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + + // Include blocks already in transmission ONLY IF + // MultiTracker.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { + blocksInRequestBitVector.synchronized { + needBlocksBitVector.or(blocksInRequestBitVector) + } + } + + // Find blocks that are neither here nor in transit + needBlocksBitVector.flip(0, needBlocksBitVector.size) + + // Blocks that should/can be requested + needBlocksBitVector.and(txHasBlocksBitVector) + + if (needBlocksBitVector.cardinality == 0) { + return -1 + } else { + // Pick uniformly the i'th required block + var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) + var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) + + while (i > 0) { + pickedBlockIndex = + needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) + i -= 1 + } + + return pickedBlockIndex + } + } + + // Pick the block that seems to be the rarest across sources + private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { + var needBlocksBitVector: BitSet = null + + // Blocks already present + hasBlocksBitVector.synchronized { + needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + + // Include blocks already in transmission ONLY IF + // MultiTracker.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { + blocksInRequestBitVector.synchronized { + 
needBlocksBitVector.or(blocksInRequestBitVector) + } + } + + // Find blocks that are neither here nor in transit + needBlocksBitVector.flip(0, needBlocksBitVector.size) + + // Blocks that should/can be requested + needBlocksBitVector.and(txHasBlocksBitVector) + + if (needBlocksBitVector.cardinality == 0) { + return -1 + } else { + // Count the number of copies for each block across all sources + var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) + + listOfSources.synchronized { + listOfSources.foreach { eachSource => + for (i <- 0 until totalBlocks) { + numCopiesPerBlock(i) += + ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) + } + } + } + + // Find the minimum + var minVal = Integer.MAX_VALUE + for (i <- 0 until totalBlocks) { + if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { + minVal = numCopiesPerBlock(i) + } + } + + // Find the blocks with the least copies that this peer does not have + var minBlocksIndices = ListBuffer[Int]() + for (i <- 0 until totalBlocks) { + if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { + minBlocksIndices += i + } + } + + // Now select a random index from minBlocksIndices + if (minBlocksIndices.size == 0) { + return -1 + } else { + // Pick uniformly the i'th index + var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) + return minBlocksIndices(i) + } + } + } + + private def cleanUpConnections() { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + + // Delete from peersNowTalking + peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } + } + } + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo]() + + override def run() { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(0) + guidePort = serverSocket.getLocalPort + logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { guidePortLock.notifyAll() } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // listOfSources.size - 1, because it includes the Guide itself + listOfSources.synchronized { + setOfCompletedSources.synchronized { + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + logInfo("GuideMultipleRequests Timeout. 
stopBroadcast == true.") + } + } + } + } + } + if (clientSocket != null) { + logDebug("Guide: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new GuideSingleRequest(clientSocket)) + } catch { + // In failure, close the socket here; else, thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + + // Shutdown the thread pool + threadPool.shutdown() + + logInfo("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + MultiTracker.unregisterBroadcast(id) + } finally { + if (serverSocket != null) { + logInfo("GuideMultipleRequests now stopping...") + serverSocket.close() + } + } + } + + private def sendStopBroadcastNotifications() { + listOfSources.synchronized { + listOfSources.foreach { sourceInfo => + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) + gosSource.flush() + gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) + + // Throw away whatever comes in + gisSource.readObject.asInstanceOf[SourceInfo] + + // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast + gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) + gosSource.flush() + } catch { + case e: Exception => { + logError("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close() + } + if (gosSource != null) { + gosSource.close() + } + if (guideSocketToSource != null) { + guideSocketToSource.close() + } + } + } + } + } + + class GuideSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var sourceInfo: SourceInfo = null + private var selectedSources: ListBuffer[SourceInfo] = null + + override def run() { + try { + logInfo("new GuideSingleRequest is running") + // Connecting worker is sending in its information + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Select a suitable source and send it back to the worker + selectedSources = selectSuitableSources(sourceInfo) + logDebug("Sending selectedSources:" + selectedSources) + oos.writeObject(selectedSources) + oos.flush() + + // Add this source to the listOfSources + addToListOfSources(sourceInfo) + } catch { + case e: Exception => { + // Assuming exception caused by receiver failure: remove + if (listOfSources != null) { + listOfSources.synchronized { listOfSources -= sourceInfo } + } + } + } finally { + logInfo("GuideSingleRequest is closing streams and sockets") + ois.close() + oos.close() + clientSocket.close() + } + } + + // Randomly select some sources to send back + private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { + var selectedSources = ListBuffer[SourceInfo]() + + // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' + // then add skipSourceInfo to setOfCompletedSources. Return blank. 
+ if (skipSourceInfo.hasBlocks == totalBlocks) { + setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } + return selectedSources + } + + listOfSources.synchronized { + if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { + selectedSources = listOfSources.clone + } else { + var picksLeft = MultiTracker.MaxPeersInGuideResponse + var alreadyPicked = new BitSet(listOfSources.size) + + while (picksLeft > 0) { + var i = -1 + + do { + i = MultiTracker.ranGen.nextInt(listOfSources.size) + } while (alreadyPicked.get(i)) + + var peerIter = listOfSources.iterator + var curPeer = peerIter.next + + // Set the BitSet before i is decremented + alreadyPicked.set(i) + + while (i > 0) { + curPeer = peerIter.next + i = i - 1 + } + + selectedSources += curPeer + + picksLeft = picksLeft - 1 + } + } + } + + // Remove the receiving source (if present) + selectedSources = selectedSources - skipSourceInfo + + return selectedSources + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + // Server at most MultiTracker.MaxChatSlots peers + var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) + + override def run() { + var serverSocket = new ServerSocket(0) + listenPort = serverSocket.getLocalPort + + logInfo("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { listenPortLock.notifyAll() } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logDebug("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ServeSingleRequest(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ServeMultipleRequests now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ServeSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ServeSingleRequest is running") + + override def run() { + try { + // Send latest local SourceInfo to the receiver + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(getLocalSourceInfo) + oos.flush() + + // Receive latest SourceInfo from the receiver + var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + addToListOfSources(rxSourceInfo) + } + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = MultiTracker.MaxChatBlocks + + while (!stopBroadcast && keepSending && numBlocksToSend > 0) { + // Receive which block to send + var blockToSend = ois.readObject.asInstanceOf[Int] + + // If it is driver AND at least one copy of each block has not been + // sent out already, MODIFY blockToSend + if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { + blockToSend = sentBlocks.getAndIncrement + } + + // Send the block + sendBlock(blockToSend) + rxSourceInfo.hasBlocksBitVector.set(blockToSend) + + numBlocksToSend -= 
1 + + // Receive latest SourceInfo from the receiver + rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) + addToListOfSources(rxSourceInfo) + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + if (curTime - startTime >= MultiTracker.MaxChatTime && + threadPool.getQueue.size > 0) { + keepSending = false + } + } + } catch { + case e: Exception => logError("ServeSingleRequest had a " + e) + } finally { + logInfo("ServeSingleRequest is closing streams and sockets") + ois.close() + oos.close() + clientSocket.close() + } + } + + private def sendBlock(blockToSend: Int) { + try { + oos.writeObject(arrayOfBlocks(blockToSend)) + oos.flush() + } catch { + case e: Exception => logError("sendBlock had a " + e) + } + logDebug("Sent block: " + blockToSend + " to " + clientSocket) + } + } + } +} + +private[spark] class BitTorrentBroadcastFactory +extends BroadcastFactory { + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new BitTorrentBroadcast[T](value_, isLocal, id) + + def stop() { MultiTracker.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala new file mode 100644 index 0000000000..43c18294c5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.io._ +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark._ + +abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { + def value: T + + // We cannot have an abstract readObject here due to some weird issues with + // readObject having to be 'private' in sub-classes. 
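[Editor's note] BitTorrentBroadcastFactory above is one pluggable implementation of the BroadcastFactory trait that appears later in this diff; which implementation a job uses is decided by the spark.broadcast.factory system property read in BroadcastManager below. A sketch of opting into it (SparkContext.broadcast is the usual driver-side entry point; it is assumed here and is not part of this diff):

    // Sketch: pick the BitTorrent implementation before the SparkContext / BroadcastManager starts.
    System.setProperty("spark.broadcast.factory",
      "org.apache.spark.broadcast.BitTorrentBroadcastFactory")

    // Typical driver-side use, assuming the usual SparkContext API (not shown in this diff):
    //   val table = sc.broadcast(largeLookupMap)
    //   rdd.map(key => table.value.getOrElse(key, defaultValue))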
+ + override def toString = "Broadcast(" + id + ")" +} + +private[spark] +class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { + + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + initialize() + + // Called by SparkContext or Executor before using Broadcast + private def initialize() { + synchronized { + if (!initialized) { + val broadcastFactoryClass = System.getProperty( + "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isDriver) + + initialized = true + } + } + } + + def stop() { + broadcastFactory.stop() + } + + private val nextBroadcastId = new AtomicLong(0) + + def newBroadcast[T](value_ : T, isLocal: Boolean) = + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + + def isDriver = _isDriver +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala new file mode 100644 index 0000000000..68bff75b90 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +/** + * An interface for all the broadcast implementations in Spark (to allow + * multiple broadcast implementations). SparkContext uses a user-specified + * BroadcastFactory implementation to instantiate a particular broadcast for the + * entire Spark job. + */ +private[spark] trait BroadcastFactory { + def initialize(isDriver: Boolean): Unit + def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala new file mode 100644 index 0000000000..9db26ae6de --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} +import java.net.URL + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import org.apache.spark.{HttpServer, Logging, SparkEnv} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet} + + +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) + extends Broadcast[T](id) with Logging with Serializable { + + def value = value_ + + def blockId: String = "broadcast_" + id + + HttpBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + if (!isLocal) { + HttpBroadcast.write(id, value_) + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + HttpBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(blockId) match { + case Some(x) => value_ = x.asInstanceOf[T] + case None => { + logInfo("Started reading broadcast variable " + id) + val start = System.nanoTime + value_ = HttpBroadcast.read[T](id) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + } +} + +private[spark] class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } +} + +private object HttpBroadcast extends Logging { + private var initialized = false + + private var broadcastDir: File = null + private var compress: Boolean = false + private var bufferSize: Int = 65536 + private var serverUri: String = null + private var server: HttpServer = null + + private val files = new TimeStampedHashSet[String] + private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) + + private lazy val compressionCodec = CompressionCodec.createCodec() + + def initialize(isDriver: Boolean) { + synchronized { + if (!initialized) { + bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + compress = System.getProperty("spark.broadcast.compress", "true").toBoolean + if (isDriver) { + createServer() + } + serverUri = System.getProperty("spark.httpBroadcast.uri") + initialized = true + } + } + } + + def stop() { + synchronized { + if (server != null) { + server.stop() + server = null + } + initialized = false + cleaner.cancel() + } + } + + private def createServer() { + broadcastDir = Utils.createTempDir(Utils.getLocalDir) + server = new HttpServer(broadcastDir) + server.start() + serverUri = server.uri + System.setProperty("spark.httpBroadcast.uri", serverUri) + logInfo("Broadcast server started at " + serverUri) + } + + def write(id: Long, value: Any) { + val file 
= new File(broadcastDir, "broadcast-" + id) + val out: OutputStream = { + if (compress) { + compressionCodec.compressedOutputStream(new FileOutputStream(file)) + } else { + new FastBufferedOutputStream(new FileOutputStream(file), bufferSize) + } + } + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject(value) + serOut.close() + files += file.getAbsolutePath + } + + def read[T](id: Long): T = { + val url = serverUri + "/broadcast-" + id + val in = { + if (compress) { + compressionCodec.compressedInputStream(new URL(url).openStream()) + } else { + new FastBufferedInputStream(new URL(url).openStream(), bufferSize) + } + } + val ser = SparkEnv.get.serializer.newInstance() + val serIn = ser.deserializeStream(in) + val obj = serIn.readObject[T]() + serIn.close() + obj + } + + def cleanup(cleanupTime: Long) { + val iterator = files.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (file, time) = (entry.getKey, entry.getValue) + if (time < cleanupTime) { + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala new file mode 100644 index 0000000000..21ec94659e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
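[Editor's note] The HTTP implementation above is configured entirely through system properties: spark.broadcast.compress and spark.buffer.size shape the streams used by write() and read(), while spark.httpBroadcast.uri is set by the driver in createServer() and read back by executors in initialize(). A configuration sketch; the values are only illustrations:

    // Sketch: the knobs HttpBroadcast reads; HttpBroadcastFactory is the default factory
    // per BroadcastManager earlier in this diff.
    System.setProperty("spark.broadcast.factory",
      "org.apache.spark.broadcast.HttpBroadcastFactory")
    System.setProperty("spark.broadcast.compress", "true")  // wrap broadcast files in the codec
    System.setProperty("spark.buffer.size", "65536")        // buffer size for uncompressed streams
    // spark.httpBroadcast.uri is not set by hand: the driver publishes it from createServer().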
+ */ + +package org.apache.spark.broadcast + +import java.io._ +import java.net._ +import java.util.Random + +import scala.collection.mutable.Map + +import org.apache.spark._ +import org.apache.spark.util.Utils + +private object MultiTracker +extends Logging { + + // Tracker Messages + val REGISTER_BROADCAST_TRACKER = 0 + val UNREGISTER_BROADCAST_TRACKER = 1 + val FIND_BROADCAST_TRACKER = 2 + + // Map to keep track of guides of ongoing broadcasts + var valueToGuideMap = Map[Long, SourceInfo]() + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var _isDriver = false + + private var stopBroadcast = false + + private var trackMV: TrackMultipleValues = null + + def initialize(__isDriver: Boolean) { + synchronized { + if (!initialized) { + _isDriver = __isDriver + + if (isDriver) { + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start() + + // Set DriverHostAddress to the driver's IP address for the slaves to read + System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) + } + + initialized = true + } + } + } + + def stop() { + stopBroadcast = true + } + + // Load common parameters + private var DriverHostAddress_ = System.getProperty( + "spark.MultiTracker.DriverHostAddress", "") + private var DriverTrackerPort_ = System.getProperty( + "spark.broadcast.driverTrackerPort", "11111").toInt + private var BlockSize_ = System.getProperty( + "spark.broadcast.blockSize", "4096").toInt * 1024 + private var MaxRetryCount_ = System.getProperty( + "spark.broadcast.maxRetryCount", "2").toInt + + private var TrackerSocketTimeout_ = System.getProperty( + "spark.broadcast.trackerSocketTimeout", "50000").toInt + private var ServerSocketTimeout_ = System.getProperty( + "spark.broadcast.serverSocketTimeout", "10000").toInt + + private var MinKnockInterval_ = System.getProperty( + "spark.broadcast.minKnockInterval", "500").toInt + private var MaxKnockInterval_ = System.getProperty( + "spark.broadcast.maxKnockInterval", "999").toInt + + // Load TreeBroadcast config params + private var MaxDegree_ = System.getProperty( + "spark.broadcast.maxDegree", "2").toInt + + // Load BitTorrentBroadcast config params + private var MaxPeersInGuideResponse_ = System.getProperty( + "spark.broadcast.maxPeersInGuideResponse", "4").toInt + + private var MaxChatSlots_ = System.getProperty( + "spark.broadcast.maxChatSlots", "4").toInt + private var MaxChatTime_ = System.getProperty( + "spark.broadcast.maxChatTime", "500").toInt + private var MaxChatBlocks_ = System.getProperty( + "spark.broadcast.maxChatBlocks", "1024").toInt + + private var EndGameFraction_ = System.getProperty( + "spark.broadcast.endGameFraction", "0.95").toDouble + + def isDriver = _isDriver + + // Common config params + def DriverHostAddress = DriverHostAddress_ + def DriverTrackerPort = DriverTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + // TreeBroadcast configs + def MaxDegree = MaxDegree_ + + // BitTorrentBroadcast configs + def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ + + def MaxChatSlots = MaxChatSlots_ + def MaxChatTime = MaxChatTime_ + def MaxChatBlocks = MaxChatBlocks_ + + def EndGameFraction = EndGameFraction_ + + class TrackMultipleValues + extends Thread with Logging { + override def run() { + var threadPool = 
Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(DriverTrackerPort) + logInfo("TrackMultipleValues started at " + serverSocket) + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(TrackerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + if (stopBroadcast) { + logInfo("Stopping TrackMultipleValues...") + } + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run() { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // First, read message type + val messageType = ois.readObject.asInstanceOf[Int] + + if (messageType == REGISTER_BROADCAST_TRACKER) { + // Receive Long + val id = ois.readObject.asInstanceOf[Long] + // Receive hostAddress and listenPort + val gInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Add to the map + valueToGuideMap.synchronized { + valueToGuideMap += (id -> gInfo) + } + + logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { + // Receive Long + val id = ois.readObject.asInstanceOf[Long] + + // Remove from the map + valueToGuideMap.synchronized { + valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) + } + + logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == FIND_BROADCAST_TRACKER) { + // Receive Long + val id = ois.readObject.asInstanceOf[Long] + + var gInfo = + if (valueToGuideMap.contains(id)) valueToGuideMap(id) + else SourceInfo("", SourceInfo.TxNotStartedRetry) + + logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) + + // Send reply back + oos.writeObject(gInfo) + oos.flush() + } else { + throw new SparkException("Undefined messageType at TrackMultipleValues") + } + } catch { + case e: Exception => { + logError("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + def getGuideInfo(variableLong: Long): SourceInfo = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) + + var retriesLeft = MultiTracker.MaxRetryCount + do { + try { + // Connect to the tracker to find out GuideInfo + clientSocketToTracker = + new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) + oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + // Send messageType/intention + oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) + oosTracker.flush() + + // Send Long and receive GuideInfo + oosTracker.writeObject(variableLong) + oosTracker.flush() + gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] + } catch { + case e: Exception => 
logError("getGuideInfo had a " + e) + } finally { + if (oisTracker != null) { + oisTracker.close() + } + if (oosTracker != null) { + oosTracker.close() + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close() + } + } + + Thread.sleep(MultiTracker.ranGen.nextInt( + MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + + MultiTracker.MinKnockInterval) + + retriesLeft -= 1 + } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) + + logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) + return gInfo + } + + def registerBroadcast(id: Long, gInfo: SourceInfo) { + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(REGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send Long of this broadcast + oosST.writeObject(id) + oosST.flush() + + // Send this tracker's information + oosST.writeObject(gInfo) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + def unregisterBroadcast(id: Long) { + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send Long of this broadcast + oosST.writeObject(id) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + // Helper method to convert an object to Array[BroadcastBlock] + def blockifyObject[IN](obj: IN): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(obj) + oos.close() + baos.close() + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BlockSize) + if (byteArray.length % BlockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BlockSize)) { + val thisBlockSize = math.min(BlockSize, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + // Helper method to convert Array[BroadcastBlock] to object + def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], + totalBytes: Int, + totalBlocks: Int): OUT = { + + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject(retByteArray) + } + + private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) + } + val retVal = in.readObject.asInstanceOf[OUT] + 
in.close() + return retVal + } +} + +private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) +extends Serializable + +private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], + totalBlocks: Int, + totalBytes: Int) +extends Serializable { + @transient var hasBlocks = 0 +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala new file mode 100644 index 0000000000..baa1fd6da4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.util.BitSet + +import org.apache.spark._ + +/** + * Used to keep and pass around information of peers involved in a broadcast + */ +private[spark] case class SourceInfo (hostAddress: String, + listenPort: Int, + totalBlocks: Int = SourceInfo.UnusedParam, + totalBytes: Int = SourceInfo.UnusedParam) +extends Comparable[SourceInfo] with Logging { + + var currentLeechers = 0 + var receptionFailed = false + + var hasBlocks = 0 + var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) + + // Ascending sort based on leecher count + def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) +} + +/** + * Helper Object of SourceInfo for its constants + */ +private[spark] object SourceInfo { + // Broadcast has not started yet! Should never happen. + val TxNotStartedRetry = -1 + // Broadcast has already finished. Try default mechanism. + val TxOverGoToDefault = -3 + // Other constants + val StopBroadcast = -2 + val UnusedParam = 0 +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala new file mode 100644 index 0000000000..80c97ca073 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala @@ -0,0 +1,603 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.broadcast + +import java.io._ +import java.net._ +import java.util.{Comparator, Random, UUID} + +import scala.collection.mutable.{ListBuffer, Map, Set} +import scala.math + +import org.apache.spark._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { + + def value = value_ + + def blockId = "broadcast_" + id + + MultiTracker.synchronized { + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + @transient var hasBlocksLock = new Object + + @transient var listOfSources = ListBuffer[SourceInfo]() + + @transient var serveMR: ServeMultipleRequests = null + @transient var guideMR: GuideMultipleRequests = null + + @transient var hostAddress = Utils.localIpAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast() + } + + def sendBroadcast() { + logInfo("Local host address: " + hostAddress) + + // Create a variableInfo object and store it in valueInfos + var variableInfo = MultiTracker.blockifyObject(value_) + + // Prepare the value being broadcasted + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + guideMR = new GuideMultipleRequests + guideMR.setDaemon(true) + guideMR.start() + logInfo("GuideMultipleRequests started...") + + // Must always come AFTER guideMR is created + while (guidePort == -1) { + guidePortLock.synchronized { guidePortLock.wait() } + } + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + // Must always come AFTER serveMR is created + while (listenPort == -1) { + listenPortLock.synchronized { listenPortLock.wait() } + } + + // Must always come AFTER listenPort is created + val masterSource = + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) + listOfSources += masterSource + + // Register with the Tracker + MultiTracker.registerBroadcast(id, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) + } + + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + MultiTracker.synchronized { + SparkEnv.get.blockManager.getSingle(blockId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + logInfo("Started reading broadcast variable " + id) + // Initializing everything because Driver will only send null/0 values + // Only the 1st worker in a node can be here. 
Others will get from cache + initializeWorkerVariables() + + logInfo("Local host address: " + hostAddress) + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(id) + if (receptionSucceeded) { + value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } else { + logError("Reading broadcast variable " + id + " failed") + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + + private def initializeWorkerVariables() { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + + listenPortLock = new Object + totalBlocksLock = new Object + hasBlocksLock = new Object + + serveMR = null + + hostAddress = Utils.localIpAddress + listenPort = -1 + + stopBroadcast = false + } + + def receiveBroadcast(variableID: Long): Boolean = { + val gInfo = MultiTracker.getGuideInfo(variableID) + + if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { listenPortLock.wait() } + } + + var clientSocketToDriver: Socket = null + var oosDriver: ObjectOutputStream = null + var oisDriver: ObjectInputStream = null + + // Connect and receive broadcast from the specified source, retrying the + // specified number of times in case of failures + var retriesLeft = MultiTracker.MaxRetryCount + do { + // Connect to Driver and send this worker's Information + clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) + oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) + oosDriver.flush() + oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) + + logDebug("Connected to Driver's guiding object") + + // Send local source information + oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) + oosDriver.flush() + + // Receive source information from Driver + var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] + totalBlocks = sourceInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) + totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } + totalBytes = sourceInfo.totalBytes + + logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) + + val start = System.nanoTime + val receptionSucceeded = receiveSingleTransmission(sourceInfo) + val time = (System.nanoTime - start) / 1e9 + + // Updating some statistics in sourceInfo. Driver will be using them later + if (!receptionSucceeded) { + sourceInfo.receptionFailed = true + } + + // Send back statistics to the Driver + oosDriver.writeObject(sourceInfo) + + if (oisDriver != null) { + oisDriver.close() + } + if (oosDriver != null) { + oosDriver.close() + } + if (clientSocketToDriver != null) { + clientSocketToDriver.close() + } + + retriesLeft -= 1 + } while (retriesLeft > 0 && hasBlocks < totalBlocks) + + return (hasBlocks == totalBlocks) + } + + /** + * Tries to receive broadcast from the source and returns Boolean status. + * This might be called multiple times to retry a defined number of times. 
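+ * Reception resumes from hasBlocks, so a retry only requests the blocks that are still missing rather than the whole variable.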
+ */ + private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { + var clientSocketToSource: Socket = null + var oosSource: ObjectOutputStream = null + var oisSource: ObjectInputStream = null + + var receptionSucceeded = false + try { + // Connect to the source to get the object itself + clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) + oosSource.flush() + oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) + + logDebug("Inside receiveSingleTransmission") + logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + + // Send the range + oosSource.writeObject((hasBlocks, totalBlocks)) + oosSource.flush() + + for (i <- hasBlocks until totalBlocks) { + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") + + arrayOfBlocks(hasBlocks) = bcBlock + hasBlocks += 1 + + // Set to true if at least one block is received + receptionSucceeded = true + hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } + } + } catch { + case e: Exception => logError("receiveSingleTransmission had a " + e) + } finally { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (clientSocketToSource != null) { + clientSocketToSource.close() + } + } + + return receptionSucceeded + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo]() + + override def run() { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(0) + guidePort = serverSocket.getLocalPort + logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { guidePortLock.notifyAll() } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // listOfSources.size - 1, because it includes the Guide itself + listOfSources.synchronized { + setOfCompletedSources.synchronized { + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + logInfo("GuideMultipleRequests Timeout. 
stopBroadcast == true.") + } + } + } + } + } + if (clientSocket != null) { + logDebug("Guide: Accepted new client connection: " + clientSocket) + try { + threadPool.execute(new GuideSingleRequest(clientSocket)) + } catch { + // In failure, close() the socket here; else, the thread will close() it + case ioe: IOException => clientSocket.close() + } + } + } + + logInfo("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + MultiTracker.unregisterBroadcast(id) + } finally { + if (serverSocket != null) { + logInfo("GuideMultipleRequests now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + private def sendStopBroadcastNotifications() { + listOfSources.synchronized { + var listIter = listOfSources.iterator + while (listIter.hasNext) { + var sourceInfo = listIter.next + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) + gosSource.flush() + gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) + + // Send stopBroadcast signal + gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) + gosSource.flush() + } catch { + case e: Exception => { + logError("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close() + } + if (gosSource != null) { + gosSource.close() + } + if (guideSocketToSource != null) { + guideSocketToSource.close() + } + } + } + } + } + + class GuideSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var selectedSourceInfo: SourceInfo = null + private var thisWorkerInfo:SourceInfo = null + + override def run() { + try { + logInfo("new GuideSingleRequest is running") + // Connecting worker is sending in its hostAddress and listenPort it will + // be listening to. Other fields are invalid (SourceInfo.UnusedParam) + var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + listOfSources.synchronized { + // Select a suitable source and send it back to the worker + selectedSourceInfo = selectSuitableSource(sourceInfo) + logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) + oos.writeObject(selectedSourceInfo) + oos.flush() + + // Add this new (if it can finish) source to the list of sources + thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes) + logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) + listOfSources += thisWorkerInfo + } + + // Wait till the whole transfer is done. Then receive and update source + // statistics in listOfSources + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + listOfSources.synchronized { + // This should work since SourceInfo is a case class + assert(listOfSources.contains(selectedSourceInfo)) + + // Remove first + // (Currently removing a source based on just one failure notification!) 
+ listOfSources = listOfSources - selectedSourceInfo + + // Update sourceInfo and put it back in, IF reception succeeded + if (!sourceInfo.receptionFailed) { + // Add thisWorkerInfo to sources that have completed reception + setOfCompletedSources.synchronized { + setOfCompletedSources += thisWorkerInfo + } + + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + listOfSources += selectedSourceInfo + } + } + } catch { + case e: Exception => { + // Remove failed worker from listOfSources and update leecherCount of + // corresponding source worker + listOfSources.synchronized { + if (selectedSourceInfo != null) { + // Remove first + listOfSources = listOfSources - selectedSourceInfo + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + listOfSources += selectedSourceInfo + } + + // Remove thisWorkerInfo + if (listOfSources != null) { + listOfSources = listOfSources - thisWorkerInfo + } + } + } + } finally { + logInfo("GuideSingleRequest is closing streams and sockets") + ois.close() + oos.close() + clientSocket.close() + } + } + + // Assuming the caller to have a synchronized block on listOfSources + // Select one with the most leechers. This will level-wise fill the tree + private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { + var maxLeechers = -1 + var selectedSource: SourceInfo = null + + listOfSources.foreach { source => + if ((source.hostAddress != skipSourceInfo.hostAddress || + source.listenPort != skipSourceInfo.listenPort) && + source.currentLeechers < MultiTracker.MaxDegree && + source.currentLeechers > maxLeechers) { + selectedSource = source + maxLeechers = source.currentLeechers + } + } + + // Update leecher count + selectedSource.currentLeechers += 1 + return selectedSource + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + + var threadPool = Utils.newDaemonCachedThreadPool() + + override def run() { + var serverSocket = new ServerSocket(0) + listenPort = serverSocket.getLocalPort + + logInfo("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { listenPortLock.notifyAll() } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { } + } + + if (clientSocket != null) { + logDebug("Serve: Accepted new client connection: " + clientSocket) + try { + threadPool.execute(new ServeSingleRequest(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ServeMultipleRequests now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ServeSingleRequest(val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + private var sendFrom = 0 + private var sendUntil = totalBlocks + + override def run() { + try { + logInfo("new ServeSingleRequest is running") + + // Receive range to send + var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] + sendFrom = rangeToSend._1 + sendUntil = rangeToSend._2 + + // If not a valid range, stop broadcast + if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { 
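+ // A (StopBroadcast, StopBroadcast) range is the shutdown signal sent by the guide, not a real block request.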
+ stopBroadcast = true + } else { + sendObject + } + } catch { + case e: Exception => logError("ServeSingleRequest had a " + e) + } finally { + logInfo("ServeSingleRequest is closing streams and sockets") + ois.close() + oos.close() + clientSocket.close() + } + } + + private def sendObject() { + // Wait till receiving the SourceInfo from Driver + while (totalBlocks == -1) { + totalBlocksLock.synchronized { totalBlocksLock.wait() } + } + + for (i <- sendFrom until sendUntil) { + while (i == hasBlocks) { + hasBlocksLock.synchronized { hasBlocksLock.wait() } + } + try { + oos.writeObject(arrayOfBlocks(i)) + oos.flush() + } catch { + case e: Exception => logError("sendObject had a " + e) + } + logDebug("Sent block: " + i + " to " + clientSocket) + } + } + } + } +} + +private[spark] class TreeBroadcastFactory +extends BroadcastFactory { + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TreeBroadcast[T](value_, isLocal, id) + + def stop() { MultiTracker.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala new file mode 100644 index 0000000000..19d393a0db --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +private[spark] class ApplicationDescription( + val name: String, + val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */ + val memoryPerSlave: Int, + val command: Command, + val sparkHome: String, + val appUiUrl: String) + extends Serializable { + + val user = System.getProperty("user.name", "<unknown>") + + override def toString: String = "ApplicationDescription(" + name + ")" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala new file mode 100644 index 0000000000..fa8af9a646 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import scala.collection.Map + +private[spark] case class Command( + mainClass: String, + arguments: Seq[String], + environment: Map[String, String]) { +} diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala new file mode 100644 index 0000000000..1cfff5e565 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import scala.collection.immutable.List + +import org.apache.spark.deploy.ExecutorState.ExecutorState +import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.util.Utils + + +private[deploy] sealed trait DeployMessage extends Serializable + +private[deploy] object DeployMessages { + + // Worker to Master + + case class RegisterWorker( + id: String, + host: String, + port: Int, + cores: Int, + memory: Int, + webUiPort: Int, + publicAddress: String) + extends DeployMessage { + Utils.checkHost(host, "Required hostname") + assert (port > 0) + } + + case class ExecutorStateChanged( + appId: String, + execId: Int, + state: ExecutorState, + message: Option[String], + exitStatus: Option[Int]) + extends DeployMessage + + case class Heartbeat(workerId: String) extends DeployMessage + + // Master to Worker + + case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage + + case class RegisterWorkerFailed(message: String) extends DeployMessage + + case class KillExecutor(appId: String, execId: Int) extends DeployMessage + + case class LaunchExecutor( + appId: String, + execId: Int, + appDesc: ApplicationDescription, + cores: Int, + memory: Int, + sparkHome: String) + extends DeployMessage + + // Client to Master + + case class RegisterApplication(appDescription: ApplicationDescription) + extends DeployMessage + + // Master to Client + + case class RegisteredApplication(appId: String) extends DeployMessage + + // TODO(matei): replace hostPort with host + case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { + Utils.checkHostPort(hostPort, "Required hostport") + } + + case class ExecutorUpdated(id: Int, state: ExecutorState, 
message: Option[String], + exitStatus: Option[Int]) + + case class ApplicationRemoved(message: String) + + // Internal message in Client + + case object StopClient + + // MasterWebUI To Master + + case object RequestMasterState + + // Master to MasterWebUI + + case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + + Utils.checkHost(host, "Required hostname") + assert (port > 0) + + def uri = "spark://" + host + ":" + port + } + + // WorkerWebUI to Worker + + case object RequestWorkerState + + // Worker to WorkerWebUI + + case class WorkerStateResponse(host: String, port: Int, workerId: String, + executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, + cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { + + Utils.checkHost(host, "Required hostname") + assert (port > 0) + } + + // Actor System to Master + + case object CheckForWorkerTimeOut + + case object RequestWebUIPort + + case class WebUIPortResponse(webUIBoundPort: Int) + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala new file mode 100644 index 0000000000..fcfea96ad6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +private[spark] object ExecutorState + extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { + + val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value + + type ExecutorState = Value + + def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala new file mode 100644 index 0000000000..a6be8efef1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import net.liftweb.json.JsonDSL._ + +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} +import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo} +import org.apache.spark.deploy.worker.ExecutorRunner + + +private[spark] object JsonProtocol { + def writeWorkerInfo(obj: WorkerInfo) = { + ("id" -> obj.id) ~ + ("host" -> obj.host) ~ + ("port" -> obj.port) ~ + ("webuiaddress" -> obj.webUiAddress) ~ + ("cores" -> obj.cores) ~ + ("coresused" -> obj.coresUsed) ~ + ("memory" -> obj.memory) ~ + ("memoryused" -> obj.memoryUsed) ~ + ("state" -> obj.state.toString) + } + + def writeApplicationInfo(obj: ApplicationInfo) = { + ("starttime" -> obj.startTime) ~ + ("id" -> obj.id) ~ + ("name" -> obj.desc.name) ~ + ("cores" -> obj.desc.maxCores) ~ + ("user" -> obj.desc.user) ~ + ("memoryperslave" -> obj.desc.memoryPerSlave) ~ + ("submitdate" -> obj.submitDate.toString) + } + + def writeApplicationDescription(obj: ApplicationDescription) = { + ("name" -> obj.name) ~ + ("cores" -> obj.maxCores) ~ + ("memoryperslave" -> obj.memoryPerSlave) ~ + ("user" -> obj.user) + } + + def writeExecutorRunner(obj: ExecutorRunner) = { + ("id" -> obj.execId) ~ + ("memory" -> obj.memory) ~ + ("appid" -> obj.appId) ~ + ("appdesc" -> writeApplicationDescription(obj.appDesc)) + } + + def writeMasterState(obj: MasterStateResponse) = { + ("url" -> ("spark://" + obj.uri)) ~ + ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ + ("cores" -> obj.workers.map(_.cores).sum) ~ + ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ + ("memory" -> obj.workers.map(_.memory).sum) ~ + ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ + ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) + } + + def writeWorkerState(obj: WorkerStateResponse) = { + ("id" -> obj.workerId) ~ + ("masterurl" -> obj.masterUrl) ~ + ("masterwebuiurl" -> obj.masterWebUiUrl) ~ + ("cores" -> obj.cores) ~ + ("coresused" -> obj.coresUsed) ~ + ("memory" -> obj.memory) ~ + ("memoryused" -> obj.memoryUsed) ~ + ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~ + ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner)) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala new file mode 100644 index 0000000000..6a7d5a85ba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} + +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.deploy.master.Master +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.{Logging} + +import scala.collection.mutable.ArrayBuffer + +/** + * Testing class that creates a Spark standalone process in-cluster (that is, running the + * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched + * by the Workers still run in separate JVMs. This can be used to test distributed operation and + * fault recovery without spinning up a lot of processes. + */ +private[spark] +class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { + + private val localHostname = Utils.localHostName() + private val masterActorSystems = ArrayBuffer[ActorSystem]() + private val workerActorSystems = ArrayBuffer[ActorSystem]() + + def start(): String = { + logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + + /* Start the Master */ + val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0) + masterActorSystems += masterSystem + val masterUrl = "spark://" + localHostname + ":" + masterPort + + /* Start the Workers */ + for (workerNum <- 1 to numWorkers) { + val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + memoryPerWorker, masterUrl, null, Some(workerNum)) + workerActorSystems += workerSystem + } + + return masterUrl + } + + def stop() { + logInfo("Shutting down local Spark cluster.") + // Stop the workers before the master so they don't get upset that it disconnected + // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! + // This is unfortunate, but for now we just comment it out. + workerActorSystems.foreach(_.shutdown()) + //workerActorSystems.foreach(_.awaitTermination()) + masterActorSystems.foreach(_.shutdown()) + //masterActorSystems.foreach(_.awaitTermination()) + masterActorSystems.clear() + workerActorSystems.clear() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..0a5f4c368f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapred.JobConf + + +/** + * Contains util methods to interact with Hadoop from spark. + */ +class SparkHadoopUtil { + + // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems + def newConfiguration(): Configuration = new Configuration() + + // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster + def addCredentials(conf: JobConf) {} + + def isYarnMode(): Boolean = { false } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/WebUI.scala b/core/src/main/scala/org/apache/spark/deploy/WebUI.scala new file mode 100644 index 0000000000..ae258b58b9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/WebUI.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.text.SimpleDateFormat +import java.util.Date + +/** + * Utilities used throughout the web UI. + */ +private[spark] object DeployWebUI { + val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + + def formatDate(date: Date): String = DATE_FORMAT.format(date) + + def formatDate(timestamp: Long): String = DATE_FORMAT.format(new Date(timestamp)) + + def formatDuration(milliseconds: Long): String = { + val seconds = milliseconds.toDouble / 1000 + if (seconds < 60) { + return "%.0f s".format(seconds) + } + val minutes = seconds / 60 + if (minutes < 10) { + return "%.1f min".format(minutes) + } else if (minutes < 60) { + return "%.0f min".format(minutes) + } + val hours = minutes / 60 + return "%.1f h".format(hours) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala new file mode 100644 index 0000000000..14a90934f6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +import java.util.concurrent.TimeoutException + +import scala.concurrent.duration._ +import scala.concurrent.Await + +import akka.actor._ +import akka.actor.Terminated +import akka.pattern.AskTimeoutException +import akka.pattern.ask +import akka.remote.RemoteClientDisconnected +import akka.remote.RemoteClientLifeCycleEvent +import akka.remote.RemoteClientShutdown + +import org.apache.spark.Logging +import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.Master + + +/** + * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, + * and a listener for cluster events, and calls back the listener when various events occur. + */ +private[spark] class Client( + actorSystem: ActorSystem, + masterUrl: String, + appDescription: ApplicationDescription, + listener: ClientListener) + extends Logging { + + var actor: ActorRef = null + var appId: String = null + + class ClientActor extends Actor with Logging { + var master: ActorRef = null + var masterAddress: Address = null + var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times + + override def preStart() { + logInfo("Connecting to master " + masterUrl) + try { + master = context.actorFor(Master.toAkkaUrl(masterUrl)) + masterAddress = master.path.address + master ! 
RegisterApplication(appDescription) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + } catch { + case e: Exception => + logError("Failed to connect to master", e) + markDisconnected() + context.stop(self) + } + } + + override def receive = { + case RegisteredApplication(appId_) => + appId = appId_ + listener.connected(appId) + + case ApplicationRemoved(message) => + logError("Master removed our application: %s; stopping client".format(message)) + markDisconnected() + context.stop(self) + + case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => + val fullId = appId + "/" + id + logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) + listener.executorAdded(fullId, workerId, hostPort, cores, memory) + + case ExecutorUpdated(id, state, message, exitStatus) => + val fullId = appId + "/" + id + val messageText = message.map(s => " (" + s + ")").getOrElse("") + logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) + if (ExecutorState.isFinished(state)) { + listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) + } + + case Terminated(actor_) if actor_ == master => + logError("Connection to master failed; stopping client") + markDisconnected() + context.stop(self) + + case RemoteClientDisconnected(transport, address) if address == masterAddress => + logError("Connection to master failed; stopping client") + markDisconnected() + context.stop(self) + + case RemoteClientShutdown(transport, address) if address == masterAddress => + logError("Connection to master failed; stopping client") + markDisconnected() + context.stop(self) + + case StopClient => + markDisconnected() + sender ! true + context.stop(self) + } + + /** + * Notify the listener that we disconnected, if we hadn't already done so before. + */ + def markDisconnected() { + if (!alreadyDisconnected) { + listener.disconnected() + alreadyDisconnected = true + } + } + } + + def start() { + // Just launch an actor; it will call back into the listener. + actor = actorSystem.actorOf(Props(new ClientActor)) + } + + def stop() { + if (actor != null) { + try { + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val future = actor.ask(StopClient)(timeout) + Await.result(future, timeout) + } catch { + case e: TimeoutException => + logInfo("Stop request to Master timed out; it may already be shut down.") + } + actor = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala new file mode 100644 index 0000000000..4605368c11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +/** + * Callbacks invoked by deploy client when various events happen. There are currently four events: + * connecting to the cluster, disconnecting, being given an executor, and having an executor + * removed (either due to failure or due to revocation). + * + * Users of this API should *not* block inside the callback methods. + */ +private[spark] trait ClientListener { + def connected(appId: String): Unit + + def disconnected(): Unit + + def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit + + def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala new file mode 100644 index 0000000000..d5e9a0e095 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.client + +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.{Logging} +import org.apache.spark.deploy.{Command, ApplicationDescription} + +private[spark] object TestClient { + + class TestListener extends ClientListener with Logging { + def connected(id: String) { + logInfo("Connected to master, got app ID " + id) + } + + def disconnected() { + logInfo("Disconnected from master") + System.exit(0) + } + + def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} + + def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} + } + + def main(args: Array[String]) { + val url = args(0) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) + val desc = new ApplicationDescription( + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") + val listener = new TestListener + val client = new Client(actorSystem, url, desc, listener) + client.start() + actorSystem.awaitTermination() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala new file mode 100644 index 0000000000..c5ac45c673 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +private[spark] object TestExecutor { + def main(args: Array[String]) { + println("Hello world!") + while (true) { + Thread.sleep(1000) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala new file mode 100644 index 0000000000..bd5327627a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import org.apache.spark.deploy.ApplicationDescription +import java.util.Date +import akka.actor.ActorRef +import scala.collection.mutable + +private[spark] class ApplicationInfo( + val startTime: Long, + val id: String, + val desc: ApplicationDescription, + val submitDate: Date, + val driver: ActorRef, + val appUiUrl: String) +{ + var state = ApplicationState.WAITING + var executors = new mutable.HashMap[Int, ExecutorInfo] + var coresGranted = 0 + var endTime = -1L + val appSource = new ApplicationSource(this) + + private var nextExecutorId = 0 + + def newExecutorId(): Int = { + val id = nextExecutorId + nextExecutorId += 1 + id + } + + def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = { + val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave) + executors(exec.id) = exec + coresGranted += cores + exec + } + + def removeExecutor(exec: ExecutorInfo) { + if (executors.contains(exec.id)) { + executors -= exec.id + coresGranted -= exec.cores + } + } + + def coresLeft: Int = desc.maxCores - coresGranted + + private var _retryCount = 0 + + def retryCount = _retryCount + + def incrementRetryCount = { + _retryCount += 1 + _retryCount + } + + def markFinished(endState: ApplicationState.Value) { + state = endState + endTime = System.currentTimeMillis() + } + + def duration: Long = { + if (endTime != -1) { + endTime - startTime + } else { + System.currentTimeMillis() - startTime + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala new file mode 100644 index 0000000000..5a24042e14 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +class ApplicationSource(val application: ApplicationInfo) extends Source { + val metricRegistry = new MetricRegistry() + val sourceName = "%s.%s.%s".format("application", application.desc.name, + System.currentTimeMillis()) + + metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { + override def getValue: String = application.state.toString + }) + + metricRegistry.register(MetricRegistry.name("runtime_ms"), new Gauge[Long] { + override def getValue: Long = application.duration + }) + + metricRegistry.register(MetricRegistry.name("cores", "number"), new Gauge[Int] { + override def getValue: Int = application.coresGranted + }) + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala new file mode 100644 index 0000000000..7e804223cf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +private[spark] object ApplicationState + extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { + + type ApplicationState = Value + + val WAITING, RUNNING, FINISHED, FAILED = Value + + val MAX_NUM_RETRY = 10 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala new file mode 100644 index 0000000000..cf384a985e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import org.apache.spark.deploy.ExecutorState + +private[spark] class ExecutorInfo( + val id: Int, + val application: ApplicationInfo, + val worker: WorkerInfo, + val cores: Int, + val memory: Int) { + + var state = ExecutorState.LAUNCHING + + def fullId: String = application.id + "/" + id +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala new file mode 100644 index 0000000000..2efd16bca0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import java.util.Date +import java.text.SimpleDateFormat + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Await +import scala.concurrent.duration._ + +import akka.actor._ +import akka.actor.Terminated +import akka.pattern.ask +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.ui.MasterWebUI +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.{Utils, AkkaUtils} +import akka.util.Timeout + + +private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 + val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt + val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt + + var nextAppNumber = 0 + val workers = new HashSet[WorkerInfo] + val idToWorker = new HashMap[String, WorkerInfo] + val actorToWorker = new HashMap[ActorRef, WorkerInfo] + val addressToWorker = new HashMap[Address, WorkerInfo] + + val apps = new HashSet[ApplicationInfo] + val idToApp = new HashMap[String, ApplicationInfo] + val actorToApp = new HashMap[ActorRef, ApplicationInfo] + val addressToApp = new HashMap[Address, ApplicationInfo] + + val waitingApps = new ArrayBuffer[ApplicationInfo] + val completedApps = new ArrayBuffer[ApplicationInfo] + + var firstApp: Option[ApplicationInfo] = None + + Utils.checkHost(host, "Expected hostname") + + val masterMetricsSystem = MetricsSystem.createMetricsSystem("master") + val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications") + val masterSource = new MasterSource(this) + + val webUi = new 
MasterWebUI(this, webUiPort) + + val masterPublicAddress = { + val envVar = System.getenv("SPARK_PUBLIC_DNS") + if (envVar != null) envVar else host + } + + // As a temporary workaround before better ways of configuring memory, we allow users to set + // a flag that will perform round-robin scheduling across the nodes (spreading out each app + // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. + val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean + + override def preStart() { + logInfo("Starting Spark master at spark://" + host + ":" + port) + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + webUi.start() + import context.dispatcher + context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + + masterMetricsSystem.registerSource(masterSource) + masterMetricsSystem.start() + applicationMetricsSystem.start() + } + + override def postStop() { + webUi.stop() + masterMetricsSystem.stop() + applicationMetricsSystem.stop() + } + + override def receive = { + case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => { + logInfo("Registering worker %s:%d with %d cores, %s RAM".format( + host, workerPort, cores, Utils.megabytesToString(memory))) + if (idToWorker.contains(id)) { + sender ! RegisterWorkerFailed("Duplicate worker ID") + } else { + addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) + context.watch(sender) // This doesn't work with remote actors but helps for testing + sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get) + schedule() + } + } + + case RegisterApplication(description) => { + logInfo("Registering app " + description.name) + val app = addApplication(description, sender) + logInfo("Registered app " + description.name + " with ID " + app.id) + waitingApps += app + context.watch(sender) // This doesn't work with remote actors but helps for testing + sender ! RegisteredApplication(app.id) + schedule() + } + + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { + val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) + execOption match { + case Some(exec) => { + exec.state = state + exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + if (ExecutorState.isFinished(state)) { + val appInfo = idToApp(appId) + // Remove this executor from the worker and app + logInfo("Removing executor " + exec.fullId + " because it is " + state) + appInfo.removeExecutor(exec) + exec.worker.removeExecutor(exec) + + // Only retry certain number of times so we don't go into an infinite loop. 
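+            // (ApplicationState.MAX_NUM_RETRY is 10, so the application is removed once its
+            // retry count reaches that limit.)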
+ if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { + schedule() + } else { + logError("Application %s with ID %s failed %d times, removing it".format( + appInfo.desc.name, appInfo.id, appInfo.retryCount)) + removeApplication(appInfo, ApplicationState.FAILED) + } + } + } + case None => + logWarning("Got status update for unknown executor " + appId + "/" + execId) + } + } + + case Heartbeat(workerId) => { + idToWorker.get(workerId) match { + case Some(workerInfo) => + workerInfo.lastHeartbeat = System.currentTimeMillis() + case None => + logWarning("Got heartbeat from unregistered worker " + workerId) + } + } + + case Terminated(actor) => { + // The disconnected actor could've been either a worker or an app; remove whichever of + // those we have an entry for in the corresponding actor hashmap + actorToWorker.get(actor).foreach(removeWorker) + actorToApp.get(actor).foreach(finishApplication) + } + + case RemoteClientDisconnected(transport, address) => { + // The disconnected client could've been either a worker or an app; remove whichever it was + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + } + + case RemoteClientShutdown(transport, address) => { + // The disconnected client could've been either a worker or an app; remove whichever it was + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + } + + case RequestMasterState => { + sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray) + } + + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() + } + + case RequestWebUIPort => { + sender ! WebUIPortResponse(webUi.boundPort.getOrElse(-1)) + } + } + + /** + * Can an app use the given worker? True if the worker has enough memory and we haven't already + * launched an executor for the app on it (right now the standalone backend doesn't like having + * two executors on the same worker). + */ + def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { + worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) + } + + /** + * Schedule the currently available resources among waiting apps. This method will be called + * every time a new app joins or resource availability changes. + */ + def schedule() { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. 
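+    // Illustrative example: with spreadOutApps enabled, three usable workers (each with at
+    // least two free cores) and an app with 5 cores left, the round-robin loop below ends up
+    // with assigned = Array(2, 2, 1) before any executors are launched.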
+ if (spreadOutApps) { + // Try to spread out each app among all the nodes, until it has all its cores + for (app <- waitingApps if app.coresLeft > 0) { + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(canUse(app, _)).sortBy(_.coresFree).reverse + val numUsable = usableWorkers.length + val assigned = new Array[Int](numUsable) // Number of cores to give on each node + var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) + var pos = 0 + while (toAssign > 0) { + if (usableWorkers(pos).coresFree - assigned(pos) > 0) { + toAssign -= 1 + assigned(pos) += 1 + } + pos = (pos + 1) % numUsable + } + // Now that we've decided how many cores to give on each node, let's actually give them + for (pos <- 0 until numUsable) { + if (assigned(pos) > 0) { + val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) + launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome) + app.state = ApplicationState.RUNNING + } + } + } + } else { + // Pack each app into as few nodes as possible until we've assigned all its cores + for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { + for (app <- waitingApps if app.coresLeft > 0) { + if (canUse(app, worker)) { + val coresToUse = math.min(worker.coresFree, app.coresLeft) + if (coresToUse > 0) { + val exec = app.addExecutor(worker, coresToUse) + launchExecutor(worker, exec, app.desc.sparkHome) + app.state = ApplicationState.RUNNING + } + } + } + } + } + } + + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { + logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) + worker.addExecutor(exec) + worker.actor ! LaunchExecutor( + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) + exec.application.driver ! ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + } + + def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, + publicAddress: String): WorkerInfo = { + // There may be one or more refs to dead workers on this same node (w/ different ID's), + // remove them. + workers.filter { w => + (w.host == host && w.port == port) && (w.state == WorkerState.DEAD) + }.foreach { w => + workers -= w + } + val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) + workers += worker + idToWorker(worker.id) = worker + actorToWorker(sender) = worker + addressToWorker(sender.path.address) = worker + worker + } + + def removeWorker(worker: WorkerInfo) { + logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) + worker.setState(WorkerState.DEAD) + idToWorker -= worker.id + actorToWorker -= worker.actor + addressToWorker -= worker.actor.path.address + for (exec <- worker.executors.values) { + logInfo("Telling app of lost executor: " + exec.id) + exec.application.driver ! 
ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.removeExecutor(exec) + } + } + + def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + val now = System.currentTimeMillis() + val date = new Date(now) + val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) + applicationMetricsSystem.registerSource(app.appSource) + apps += app + idToApp(app.id) = app + actorToApp(driver) = app + addressToApp(driver.path.address) = app + if (firstApp == None) { + firstApp = Some(app) + } + val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray + if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) { + logWarning("Could not find any workers with enough memory for " + firstApp.get.id) + } + app + } + + def finishApplication(app: ApplicationInfo) { + removeApplication(app, ApplicationState.FINISHED) + } + + def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) { + if (apps.contains(app)) { + logInfo("Removing app " + app.id) + apps -= app + idToApp -= app.id + actorToApp -= app.driver + addressToApp -= app.driver.path.address + if (completedApps.size >= RETAINED_APPLICATIONS) { + val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) + completedApps.take(toRemove).foreach( a => { + applicationMetricsSystem.removeSource(a.appSource) + }) + completedApps.trimStart(toRemove) + } + completedApps += app // Remember it in our history + waitingApps -= app + for (exec <- app.executors.values) { + exec.worker.removeExecutor(exec) + exec.worker.actor ! KillExecutor(exec.application.id, exec.id) + exec.state = ExecutorState.KILLED + } + app.markFinished(state) + if (state != ApplicationState.FINISHED) { + app.driver ! ApplicationRemoved(state.toString) + } + schedule() + } + } + + /** Generate a new app ID given a app's submission date */ + def newApplicationId(submitDate: Date): String = { + val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber) + nextAppNumber += 1 + appId + } + + /** Check for, and remove, any timed-out workers */ + def timeOutDeadWorkers() { + // Copy the workers into an array so we don't modify the hashset while iterating through it + val currentTime = System.currentTimeMillis() + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + for (worker <- toRemove) { + if (worker.state != WorkerState.DEAD) { + logWarning("Removing %s because we got no heartbeat in %d seconds".format( + worker.id, WORKER_TIMEOUT/1000)) + removeWorker(worker) + } else { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) + workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it + } + } + } +} + +private[spark] object Master { + private val systemName = "sparkMaster" + private val actorName = "Master" + private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r + + def main(argStrings: Array[String]) { + val args = new MasterArguments(argStrings) + val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort) + actorSystem.awaitTermination() + } + + /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. 
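+   * For example (an illustrative sketch):
+   * {{{
+   *   Master.toAkkaUrl("spark://localhost:7077")
+   *   // returns "akka://sparkMaster@localhost:7077/user/Master"
+   * }}}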
*/ + def toAkkaUrl(sparkUrl: String): String = { + sparkUrl match { + case sparkUrlRegex(host, port) => + "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + case _ => + throw new SparkException("Invalid master URL: " + sparkUrl) + } + } + + def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) + val timeoutDuration = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + implicit val timeout = Timeout(timeoutDuration) + val respFuture = actor ? RequestWebUIPort // ask pattern + val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse] + (actorSystem, boundPort, resp.webUIBoundPort) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala new file mode 100644 index 0000000000..9d89b455fb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import org.apache.spark.util.{Utils, IntParam} + +/** + * Command-line parser for the master. + */ +private[spark] class MasterArguments(args: Array[String]) { + var host = Utils.localHostName() + var port = 7077 + var webUiPort = 8080 + + // Check for settings in environment variables + if (System.getenv("SPARK_MASTER_HOST") != null) { + host = System.getenv("SPARK_MASTER_HOST") + } + if (System.getenv("SPARK_MASTER_PORT") != null) { + port = System.getenv("SPARK_MASTER_PORT").toInt + } + if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { + webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt + } + if (System.getProperty("master.ui.port") != null) { + webUiPort = System.getProperty("master.ui.port").toInt + } + + parse(args.toList) + + def parse(args: List[String]): Unit = args match { + case ("--ip" | "-i") :: value :: tail => + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case "--webui-port" :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--help" | "-h") :: tail => + printUsageAndExit(0) + + case Nil => {} + + case _ => + printUsageAndExit(1) + } + + /** + * Print usage and exit JVM with the given exit code. 
+ */ + def printUsageAndExit(exitCode: Int) { + System.err.println( + "Usage: Master [options]\n" + + "\n" + + "Options:\n" + + " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port PORT Port for web UI (default: 8080)") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala new file mode 100644 index 0000000000..23d1cb77da --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[spark] class MasterSource(val master: Master) extends Source { + val metricRegistry = new MetricRegistry() + val sourceName = "master" + + // Gauge for worker numbers in cluster + metricRegistry.register(MetricRegistry.name("workers","number"), new Gauge[Int] { + override def getValue: Int = master.workers.size + }) + + // Gauge for application numbers in cluster + metricRegistry.register(MetricRegistry.name("apps", "number"), new Gauge[Int] { + override def getValue: Int = master.apps.size + }) + + // Gauge for waiting application numbers in cluster + metricRegistry.register(MetricRegistry.name("waitingApps", "number"), new Gauge[Int] { + override def getValue: Int = master.waitingApps.size + }) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala new file mode 100644 index 0000000000..6219f11f2a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import akka.actor.ActorRef +import scala.collection.mutable +import org.apache.spark.util.Utils + +private[spark] class WorkerInfo( + val id: String, + val host: String, + val port: Int, + val cores: Int, + val memory: Int, + val actor: ActorRef, + val webUiPort: Int, + val publicAddress: String) { + + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + + var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info + var state: WorkerState.Value = WorkerState.ALIVE + var coresUsed = 0 + var memoryUsed = 0 + + var lastHeartbeat = System.currentTimeMillis() + + def coresFree: Int = cores - coresUsed + def memoryFree: Int = memory - memoryUsed + + def hostPort: String = { + assert (port > 0) + host + ":" + port + } + + def addExecutor(exec: ExecutorInfo) { + executors(exec.fullId) = exec + coresUsed += exec.cores + memoryUsed += exec.memory + } + + def removeExecutor(exec: ExecutorInfo) { + if (executors.contains(exec.fullId)) { + executors -= exec.fullId + coresUsed -= exec.cores + memoryUsed -= exec.memory + } + } + + def hasExecutor(app: ApplicationInfo): Boolean = { + executors.values.exists(_.application == app) + } + + def webUiAddress : String = { + "http://" + this.publicAddress + ":" + this.webUiPort + } + + def setState(state: WorkerState.Value) = { + this.state = state + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala new file mode 100644 index 0000000000..b5ee6dca79 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") { + type WorkerState = Value + + val ALIVE, DEAD, DECOMMISSIONED = Value +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala new file mode 100644 index 0000000000..3b983c19eb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import scala.xml.Node + +import akka.pattern.ask + +import scala.concurrent.Await +import scala.concurrent.duration._ + +import javax.servlet.http.HttpServletRequest + +import net.liftweb.json.JsonAST.JValue + +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.JsonProtocol +import org.apache.spark.deploy.master.ExecutorInfo +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils + +private[spark] class ApplicationPage(parent: MasterWebUI) { + val master = parent.masterActorRef + implicit val timeout = parent.timeout + + /** Executor details for a particular application */ + def renderJson(request: HttpServletRequest): JValue = { + val appId = request.getParameter("appId") + val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] + val state = Await.result(stateFuture, 30 seconds) + val app = state.activeApps.find(_.id == appId).getOrElse({ + state.completedApps.find(_.id == appId).getOrElse(null) + }) + JsonProtocol.writeApplicationInfo(app) + } + + /** Executor details for a particular application */ + def render(request: HttpServletRequest): Seq[Node] = { + val appId = request.getParameter("appId") + val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] + val state = Await.result(stateFuture, 30 seconds) + val app = state.activeApps.find(_.id == appId).getOrElse({ + state.completedApps.find(_.id == appId).getOrElse(null) + }) + + val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") + val executors = app.executors.values.toSeq + val executorTable = UIUtils.listingTable(executorHeaders, executorRow, executors) + + val content = + <div class="row-fluid"> + <div class="span12"> + <ul class="unstyled"> + <li><strong>ID:</strong> {app.id}</li> + <li><strong>Name:</strong> {app.desc.name}</li> + <li><strong>User:</strong> {app.desc.user}</li> + <li><strong>Cores:</strong> + { + if (app.desc.maxCores == Integer.MAX_VALUE) { + "Unlimited (%s granted)".format(app.coresGranted) + } else { + "%s (%s granted, %s left)".format( + app.desc.maxCores, app.coresGranted, app.coresLeft) + } + } + </li> + <li> + <strong>Executor Memory:</strong> + {Utils.megabytesToString(app.desc.memoryPerSlave)} + </li> + <li><strong>Submit Date:</strong> {app.submitDate}</li> + <li><strong>State:</strong> {app.state}</li> + <li><strong><a href={app.appUiUrl}>Application Detail UI</a></strong></li> + </ul> + </div> + </div> + + <div class="row-fluid"> <!-- Executors --> + <div class="span12"> + <h4> Executor Summary </h4> + {executorTable} + </div> + </div>; + UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + } + + def executorRow(executor: ExecutorInfo): Seq[Node] = { + <tr> + <td>{executor.id}</td> + <td> + <a href={executor.worker.webUiAddress}>{executor.worker.id}</a> + </td> + <td>{executor.cores}</td> + <td>{executor.memory}</td> + <td>{executor.state}</td> + <td> + <a href={"%s/logPage?appId=%s&executorId=%s&logType=stdout" + 
.format(executor.worker.webUiAddress, executor.application.id, executor.id)}>stdout</a> + <a href={"%s/logPage?appId=%s&executorId=%s&logType=stderr" + .format(executor.worker.webUiAddress, executor.application.id, executor.id)}>stderr</a> + </td> + </tr> + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala new file mode 100644 index 0000000000..65e7a14e7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import scala.concurrent.Await +import akka.pattern.ask +import scala.concurrent.duration._ + +import net.liftweb.json.JsonAST.JValue + +import org.apache.spark.deploy.DeployWebUI +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.JsonProtocol +import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo} +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils + +private[spark] class IndexPage(parent: MasterWebUI) { + val master = parent.masterActorRef + implicit val timeout = parent.timeout + + def renderJson(request: HttpServletRequest): JValue = { + val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] + val state = Await.result(stateFuture, 30 seconds) + JsonProtocol.writeMasterState(state) + } + + /** Index view listing applications and executors */ + def render(request: HttpServletRequest): Seq[Node] = { + val stateFuture = (master ? 
RequestMasterState)(timeout).mapTo[MasterStateResponse] + val state = Await.result(stateFuture, 30 seconds) + + val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") + val workers = state.workers.sortBy(_.id) + val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) + + val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", + "State", "Duration") + val activeApps = state.activeApps.sortBy(_.startTime).reverse + val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) + val completedApps = state.completedApps.sortBy(_.endTime).reverse + val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) + + val content = + <div class="row-fluid"> + <div class="span12"> + <ul class="unstyled"> + <li><strong>URL:</strong> {state.uri}</li> + <li><strong>Workers:</strong> {state.workers.size}</li> + <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total, + {state.workers.map(_.coresUsed).sum} Used</li> + <li><strong>Memory:</strong> + {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, + {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li> + <li><strong>Applications:</strong> + {state.activeApps.size} Running, + {state.completedApps.size} Completed </li> + </ul> + </div> + </div> + + <div class="row-fluid"> + <div class="span12"> + <h4> Workers </h4> + {workerTable} + </div> + </div> + + <div class="row-fluid"> + <div class="span12"> + <h4> Running Applications </h4> + + {activeAppsTable} + </div> + </div> + + <div class="row-fluid"> + <div class="span12"> + <h4> Completed Applications </h4> + {completedAppsTable} + </div> + </div>; + UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + } + + def workerRow(worker: WorkerInfo): Seq[Node] = { + <tr> + <td> + <a href={worker.webUiAddress}>{worker.id}</a> + </td> + <td>{worker.host}:{worker.port}</td> + <td>{worker.state}</td> + <td>{worker.cores} ({worker.coresUsed} Used)</td> + <td sorttable_customkey={"%s.%s".format(worker.memory, worker.memoryUsed)}> + {Utils.megabytesToString(worker.memory)} + ({Utils.megabytesToString(worker.memoryUsed)} Used) + </td> + </tr> + } + + + def appRow(app: ApplicationInfo): Seq[Node] = { + <tr> + <td> + <a href={"app?appId=" + app.id}>{app.id}</a> + </td> + <td> + <a href={app.appUiUrl}>{app.desc.name}</a> + </td> + <td> + {app.coresGranted} + </td> + <td sorttable_customkey={app.desc.memoryPerSlave.toString}> + {Utils.megabytesToString(app.desc.memoryPerSlave)} + </td> + <td>{DeployWebUI.formatDate(app.submitDate)}</td> + <td>{app.desc.user}</td> + <td>{app.state.toString}</td> + <td>{DeployWebUI.formatDuration(app.duration)}</td> + </tr> + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala new file mode 100644 index 0000000000..a211ce2b42 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import scala.concurrent.duration._ + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.{Handler, Server} + +import org.apache.spark.{Logging} +import org.apache.spark.deploy.master.Master +import org.apache.spark.ui.JettyUtils +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.util.Utils + +/** + * Web UI server for the standalone master. + */ +private[spark] +class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { + implicit val timeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val host = Utils.localHostName() + val port = requestedPort + + val masterActorRef = master.self + + var server: Option[Server] = None + var boundPort: Option[Int] = None + + val applicationPage = new ApplicationPage(this) + val indexPage = new IndexPage(this) + + def start() { + try { + val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers) + server = Some(srv) + boundPort = Some(bPort) + logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) + } catch { + case e: Exception => + logError("Failed to create Master JettyUtils", e) + System.exit(1) + } + } + + val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ + master.applicationMetricsSystem.getServletHandlers + + val handlers = metricsHandlers ++ Array[(String, Handler)]( + ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), + ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), + ("/app", (request: HttpServletRequest) => applicationPage.render(request)), + ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), + ("*", (request: HttpServletRequest) => indexPage.render(request)) + ) + + def stop() { + server.foreach(_.stop()) + } +} + +private[spark] object MasterWebUI { + val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala new file mode 100644 index 0000000000..e3dc30eefc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.worker + +import java.io._ +import java.lang.System.getenv + +import akka.actor.ActorRef + +import com.google.common.base.Charsets +import com.google.common.io.Files + +import org.apache.spark.{Logging} +import org.apache.spark.deploy.{ExecutorState, ApplicationDescription} +import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.Utils + +/** + * Manages the execution of one executor process. + */ +private[spark] class ExecutorRunner( + val appId: String, + val execId: Int, + val appDesc: ApplicationDescription, + val cores: Int, + val memory: Int, + val worker: ActorRef, + val workerId: String, + val host: String, + val sparkHome: File, + val workDir: File) + extends Logging { + + val fullId = appId + "/" + execId + var workerThread: Thread = null + var process: Process = null + var shutdownHook: Thread = null + + private def getAppEnv(key: String): Option[String] = + appDesc.command.environment.get(key).orElse(Option(getenv(key))) + + def start() { + workerThread = new Thread("ExecutorRunner for " + fullId) { + override def run() { fetchAndRunExecutor() } + } + workerThread.start() + + // Shutdown hook that kills actors on shutdown. + shutdownHook = new Thread() { + override def run() { + if (process != null) { + logInfo("Shutdown hook killing child process.") + process.destroy() + process.waitFor() + } + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + } + + /** Stop this executor runner, including killing the process it launched */ + def kill() { + if (workerThread != null) { + workerThread.interrupt() + workerThread = null + if (process != null) { + logInfo("Killing process!") + process.destroy() + process.waitFor() + } + worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None) + Runtime.getRuntime.removeShutdownHook(shutdownHook) + } + } + + /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ + def substituteVariables(argument: String): String = argument match { + case "{{EXECUTOR_ID}}" => execId.toString + case "{{HOSTNAME}}" => host + case "{{CORES}}" => cores.toString + case other => other + } + + def buildCommandSeq(): Seq[String] = { + val command = appDesc.command + val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") + // SPARK-698: do not call the run.cmd script, as process.destroy() + // fails to kill a process tree on Windows + Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ + command.arguments.map(substituteVariables) + } + + /** + * Attention: this must always be aligned with the environment variables in the run scripts and + * the way the JAVA_OPTS are assembled there. 
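+   * The resulting order is: the classpath ("-cp" ...), java.library.path options from
+   * SPARK_LIBRARY_PATH, worker-local SPARK_JAVA_OPTS, the application's SPARK_JAVA_OPTS, and
+   * finally the -Xms/-Xmx flags derived from the executor's memory allocation.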
+ */ + def buildJavaOpts(): Seq[String] = { + val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH") + .map(p => List("-Djava.library.path=" + p)) + .getOrElse(Nil) + val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) + val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil) + val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") + + // Figure out our classpath with the external compute-classpath script + val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" + val classPath = Utils.executeAndGetOutput( + Seq(sparkHome + "/bin/compute-classpath" + ext), + extraEnvironment=appDesc.command.environment) + + Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts + } + + /** Spawn a thread that will redirect a given stream to a file */ + def redirectStream(in: InputStream, file: File) { + val out = new FileOutputStream(file, true) + new Thread("redirect output to " + file) { + override def run() { + try { + Utils.copyStream(in, out, true) + } catch { + case e: IOException => + logInfo("Redirection to " + file + " closed: " + e.getMessage) + } + } + }.start() + } + + /** + * Download and run the executor described in our ApplicationDescription + */ + def fetchAndRunExecutor() { + try { + // Create the executor's working directory + val executorDir = new File(workDir, appId + "/" + execId) + if (!executorDir.mkdirs()) { + throw new IOException("Failed to create directory " + executorDir) + } + + // Launch the process + val command = buildCommandSeq() + logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) + val builder = new ProcessBuilder(command: _*).directory(executorDir) + val env = builder.environment() + for ((key, value) <- appDesc.command.environment) { + env.put(key, value) + } + // In case we are running this from within the Spark Shell, avoid creating a "scala" + // parent process for the executor command + env.put("SPARK_LAUNCH_WITH_SCALA", "0") + process = builder.start() + + val header = "Spark Executor Command: %s\n%s\n\n".format( + command.mkString("\"", "\" \"", "\""), "=" * 40) + + // Redirect its stdout and stderr to files + val stdout = new File(executorDir, "stdout") + redirectStream(process.getInputStream, stdout) + + val stderr = new File(executorDir, "stderr") + Files.write(header, stderr, Charsets.UTF_8) + redirectStream(process.getErrorStream, stderr) + + // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run + // long-lived processes only. However, in the future, we might restart the executor a few + // times on the same machine. + val exitCode = process.waitFor() + val message = "Command exited with code " + exitCode + worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), + Some(exitCode)) + } catch { + case interrupted: InterruptedException => + logInfo("Runner thread for executor " + fullId + " interrupted") + + case e: Exception => { + logError("Error running executor", e) + if (process != null) { + process.destroy() + } + val message = e.getClass + ": " + e.getMessage + worker ! 
ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala new file mode 100644 index 0000000000..a0a9d1040a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import java.text.SimpleDateFormat +import java.util.Date +import java.io.File + +import scala.collection.mutable.HashMap +import scala.concurrent.duration._ + +import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} + +import org.apache.spark.{Logging} +import org.apache.spark.deploy.ExecutorState +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.Master +import org.apache.spark.deploy.worker.ui.WorkerWebUI +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.{Utils, AkkaUtils} + + +private[spark] class Worker( + host: String, + port: Int, + webUiPort: Int, + cores: Int, + memory: Int, + masterUrl: String, + workDirPath: String = null) + extends Actor with Logging { + + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs + + // Send a heartbeat every (heartbeat timeout) / 4 milliseconds + val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4 + + var master: ActorRef = null + var masterWebUiUrl : String = "" + val workerId = generateWorkerId() + var sparkHome: File = null + var workDir: File = null + val executors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new HashMap[String, ExecutorRunner] + val publicAddress = { + val envVar = System.getenv("SPARK_PUBLIC_DNS") + if (envVar != null) envVar else host + } + var webUi: WorkerWebUI = null + + var coresUsed = 0 + var memoryUsed = 0 + + val metricsSystem = MetricsSystem.createMetricsSystem("worker") + val workerSource = new WorkerSource(this) + + def coresFree: Int = cores - coresUsed + def memoryFree: Int = memory - memoryUsed + + def createWorkDir() { + workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) + try { + // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() + // So attempting to create and then check if directory was created or not. 
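+      // Note that File.mkdirs() returns false if the directory already exists, so we check
+      // exists()/isDirectory afterwards rather than relying on its return value.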
+ workDir.mkdirs() + if ( !workDir.exists() || !workDir.isDirectory) { + logError("Failed to create work directory " + workDir) + System.exit(1) + } + assert (workDir.isDirectory) + } catch { + case e: Exception => + logError("Failed to create work directory " + workDir, e) + System.exit(1) + } + } + + override def preStart() { + logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( + host, port, cores, Utils.megabytesToString(memory))) + sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) + logInfo("Spark home: " + sparkHome) + createWorkDir() + webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) + + webUi.start() + connectToMaster() + + metricsSystem.registerSource(workerSource) + metricsSystem.start() + } + + def connectToMaster() { + logInfo("Connecting to master " + masterUrl) + master = context.actorFor(Master.toAkkaUrl(masterUrl)) + master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + } + + import context.dispatcher + + override def receive = { + case RegisteredWorker(url) => + masterWebUiUrl = url + logInfo("Successfully registered with master") + context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) { + master ! Heartbeat(workerId) + } + + case RegisterWorkerFailed(message) => + logError("Worker registration failed: " + message) + System.exit(1) + + case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => + logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + val manager = new ExecutorRunner( + appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir) + executors(appId + "/" + execId) = manager + manager.start() + coresUsed += cores_ + memoryUsed += memory_ + master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None) + + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => + master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + val fullId = appId + "/" + execId + if (ExecutorState.isFinished(state)) { + val executor = executors(fullId) + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + finishedExecutors(fullId) = executor + executors -= fullId + coresUsed -= executor.cores + memoryUsed -= executor.memory + } + + case KillExecutor(appId, execId) => + val fullId = appId + "/" + execId + executors.get(fullId) match { + case Some(executor) => + logInfo("Asked to kill executor " + fullId) + executor.kill() + case None => + logInfo("Asked to kill unknown executor " + fullId) + } + + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + masterDisconnected() + + case RequestWorkerState => { + sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, masterUrl, cores, memory, + coresUsed, memoryUsed, masterWebUiUrl) + } + } + + def masterDisconnected() { + // TODO: It would be nice to try to reconnect to the master, but just shut down for now. + // (Note that if reconnecting we would also need to assign IDs differently.) + logError("Connection to master failed! 
Shutting down.") + executors.values.foreach(_.kill()) + System.exit(1) + } + + def generateWorkerId(): String = { + "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) + } + + override def postStop() { + executors.values.foreach(_.kill()) + webUi.stop() + metricsSystem.stop() + } +} + +private[spark] object Worker { + def main(argStrings: Array[String]) { + val args = new WorkerArguments(argStrings) + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + args.memory, args.master, args.workDir) + actorSystem.awaitTermination() + } + + def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, + masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + // The LocalSparkCluster runs multiple local sparkWorkerX actor systems + val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, + masterUrl, workDir)), name = "Worker") + (actorSystem, boundPort) + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala new file mode 100644 index 0000000000..0ae89a864f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.util.{Utils, IntParam, MemoryParam} +import java.lang.management.ManagementFactory + +/** + * Command-line parser for the master. 
+ */ +private[spark] class WorkerArguments(args: Array[String]) { + var host = Utils.localHostName() + var port = 0 + var webUiPort = 8081 + var cores = inferDefaultCores() + var memory = inferDefaultMemory() + var master: String = null + var workDir: String = null + + // Check for settings in environment variables + if (System.getenv("SPARK_WORKER_PORT") != null) { + port = System.getenv("SPARK_WORKER_PORT").toInt + } + if (System.getenv("SPARK_WORKER_CORES") != null) { + cores = System.getenv("SPARK_WORKER_CORES").toInt + } + if (System.getenv("SPARK_WORKER_MEMORY") != null) { + memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) + } + if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { + webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt + } + if (System.getenv("SPARK_WORKER_DIR") != null) { + workDir = System.getenv("SPARK_WORKER_DIR") + } + + parse(args.toList) + + def parse(args: List[String]): Unit = args match { + case ("--ip" | "-i") :: value :: tail => + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--cores" | "-c") :: IntParam(value) :: tail => + cores = value + parse(tail) + + case ("--memory" | "-m") :: MemoryParam(value) :: tail => + memory = value + parse(tail) + + case ("--work-dir" | "-d") :: value :: tail => + workDir = value + parse(tail) + + case "--webui-port" :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--help" | "-h") :: tail => + printUsageAndExit(0) + + case value :: tail => + if (master != null) { // Two positional arguments were given + printUsageAndExit(1) + } + master = value + parse(tail) + + case Nil => + if (master == null) { // No positional argument was given + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1) + } + + /** + * Print usage and exit JVM with the given exit code. + */ + def printUsageAndExit(exitCode: Int) { + System.err.println( + "Usage: Worker [options] <master>\n" + + "\n" + + "Master must be a URL of the form spark://hostname:port\n" + + "\n" + + "Options:\n" + + " -c CORES, --cores CORES Number of cores to use\n" + + " -m MEM, --memory MEM Amount of memory to use (e.g. 
1000M, 2G)\n" + + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + + " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: random)\n" + + " --webui-port PORT Port for web UI (default: 8081)") + System.exit(exitCode) + } + + def inferDefaultCores(): Int = { + Runtime.getRuntime.availableProcessors() + } + + def inferDefaultMemory(): Int = { + val ibmVendor = System.getProperty("java.vendor").contains("IBM") + var totalMb = 0 + try { + val bean = ManagementFactory.getOperatingSystemMXBean() + if (ibmVendor) { + val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } else { + val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } + } catch { + case e: Exception => { + totalMb = 2*1024 + System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + } + } + // Leave out 1 GB for the operating system, but don't return a negative memory size + math.max(totalMb - 1024, 512) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala new file mode 100644 index 0000000000..df269fd047 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.worker + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[spark] class WorkerSource(val worker: Worker) extends Source { + val sourceName = "worker" + val metricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] { + override def getValue: Int = worker.executors.size + }) + + // Gauge for cores used of this worker + metricRegistry.register(MetricRegistry.name("coresUsed", "number"), new Gauge[Int] { + override def getValue: Int = worker.coresUsed + }) + + // Gauge for memory used of this worker + metricRegistry.register(MetricRegistry.name("memUsed", "MBytes"), new Gauge[Int] { + override def getValue: Int = worker.memoryUsed + }) + + // Gauge for cores free of this worker + metricRegistry.register(MetricRegistry.name("coresFree", "number"), new Gauge[Int] { + override def getValue: Int = worker.coresFree + }) + + // Gauge for memory free of this worker + metricRegistry.register(MetricRegistry.name("memFree", "MBytes"), new Gauge[Int] { + override def getValue: Int = worker.memoryFree + }) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala new file mode 100644 index 0000000000..1a768d501f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import scala.concurrent.duration._ +import scala.concurrent.Await + +import akka.pattern.ask + +import net.liftweb.json.JsonAST.JValue + +import org.apache.spark.deploy.JsonProtocol +import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils + + +private[spark] class IndexPage(parent: WorkerWebUI) { + val workerActor = parent.worker.self + val worker = parent.worker + val timeout = parent.timeout + + def renderJson(request: HttpServletRequest): JValue = { + val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] + val workerState = Await.result(stateFuture, 30 seconds) + JsonProtocol.writeWorkerState(workerState) + } + + def render(request: HttpServletRequest): Seq[Node] = { + val stateFuture = (workerActor ? 
RequestWorkerState)(timeout).mapTo[WorkerStateResponse] + val workerState = Await.result(stateFuture, 30 seconds) + + val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs") + val runningExecutorTable = + UIUtils.listingTable(executorHeaders, executorRow, workerState.executors) + val finishedExecutorTable = + UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) + + val content = + <div class="row-fluid"> <!-- Worker Details --> + <div class="span12"> + <ul class="unstyled"> + <li><strong>ID:</strong> {workerState.workerId}</li> + <li><strong> + Master URL:</strong> {workerState.masterUrl} + </li> + <li><strong>Cores:</strong> {workerState.cores} ({workerState.coresUsed} Used)</li> + <li><strong>Memory:</strong> {Utils.megabytesToString(workerState.memory)} + ({Utils.megabytesToString(workerState.memoryUsed)} Used)</li> + </ul> + <p><a href={workerState.masterWebUiUrl}>Back to Master</a></p> + </div> + </div> + + <div class="row-fluid"> <!-- Running Executors --> + <div class="span12"> + <h4> Running Executors {workerState.executors.size} </h4> + {runningExecutorTable} + </div> + </div> + + <div class="row-fluid"> <!-- Finished Executors --> + <div class="span12"> + <h4> Finished Executors </h4> + {finishedExecutorTable} + </div> + </div>; + + UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + workerState.host, workerState.port)) + } + + def executorRow(executor: ExecutorRunner): Seq[Node] = { + <tr> + <td>{executor.execId}</td> + <td>{executor.cores}</td> + <td sorttable_customkey={executor.memory.toString}> + {Utils.megabytesToString(executor.memory)} + </td> + <td> + <ul class="unstyled"> + <li><strong>ID:</strong> {executor.appId}</li> + <li><strong>Name:</strong> {executor.appDesc.name}</li> + <li><strong>User:</strong> {executor.appDesc.user}</li> + </ul> + </td> + <td> + <a href={"logPage?appId=%s&executorId=%s&logType=stdout" + .format(executor.appId, executor.execId)}>stdout</a> + <a href={"logPage?appId=%s&executorId=%s&logType=stderr" + .format(executor.appId, executor.execId)}>stderr</a> + </td> + </tr> + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala new file mode 100644 index 0000000000..07bc479c83 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
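(Illustrative aside; the IDs are hypothetical.) The stdout/stderr links built in executorRow above resolve to worker-UI URLs of the form

  logPage?appId=app-20130902-0001&executorId=3&logType=stderr

and the same worker state rendered by render is also available as JSON via renderJson, which the WorkerWebUI below exposes under /json.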
+ */ + +package org.apache.spark.deploy.worker.ui + +import akka.actor.ActorRef +import akka.util.Timeout + +import scala.concurrent.duration._ + +import java.io.{FileInputStream, File} + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.{Handler, Server} + +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.{Logging} +import org.apache.spark.ui.JettyUtils +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils + +/** + * Web UI server for the standalone worker. + */ +private[spark] +class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) + extends Logging { + implicit val timeout = Timeout( + Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) + val host = Utils.localHostName() + val port = requestedPort.getOrElse( + System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt) + + var server: Option[Server] = None + var boundPort: Option[Int] = None + + val indexPage = new IndexPage(this) + + val metricsHandlers = worker.metricsSystem.getServletHandlers + + val handlers = metricsHandlers ++ Array[(String, Handler)]( + ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), + ("/log", (request: HttpServletRequest) => log(request)), + ("/logPage", (request: HttpServletRequest) => logPage(request)), + ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), + ("*", (request: HttpServletRequest) => indexPage.render(request)) + ) + + def start() { + try { + val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers) + server = Some(srv) + boundPort = Some(bPort) + logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) + } catch { + case e: Exception => + logError("Failed to create Worker JettyUtils", e) + System.exit(1) + } + } + + def log(request: HttpServletRequest): String = { + val defaultBytes = 100 * 1024 + val appId = request.getParameter("appId") + val executorId = request.getParameter("executorId") + val logType = request.getParameter("logType") + val offset = Option(request.getParameter("offset")).map(_.toLong) + val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) + + val (startByte, endByte) = getByteRange(path, offset, byteLength) + val file = new File(path) + val logLength = file.length + + val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n" + .format(startByte, endByte, logLength, appId, executorId, logType) + pre + Utils.offsetBytes(path, startByte, endByte) + } + + def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = { + val defaultBytes = 100 * 1024 + val appId = request.getParameter("appId") + val executorId = request.getParameter("executorId") + val logType = request.getParameter("logType") + val offset = Option(request.getParameter("offset")).map(_.toLong) + val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) + + val (startByte, endByte) = getByteRange(path, offset, byteLength) + val file = new File(path) + val logLength = file.length + + val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node> + + val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p> + + val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span> 
+ + val backButton = + if (startByte > 0) { + <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s" + .format(appId, executorId, logType, math.max(startByte-byteLength, 0), + byteLength)}> + <button type="button" class="btn btn-default"> + Previous {Utils.bytesToString(math.min(byteLength, startByte))} + </button> + </a> + } + else { + <button type="button" class="btn btn-default" disabled="disabled"> + Previous 0 B + </button> + } + + val nextButton = + if (endByte < logLength) { + <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s". + format(appId, executorId, logType, endByte, byteLength)}> + <button type="button" class="btn btn-default"> + Next {Utils.bytesToString(math.min(byteLength, logLength-endByte))} + </button> + </a> + } + else { + <button type="button" class="btn btn-default" disabled="disabled"> + Next 0 B + </button> + } + + val content = + <html> + <body> + {linkToMaster} + <div> + <div style="float:left;width:40%">{backButton}</div> + <div style="float:left;">{range}</div> + <div style="float:right;">{nextButton}</div> + </div> + <br /> + <div style="height:500px;overflow:auto;padding:5px;"> + <pre>{logText}</pre> + </div> + </body> + </html> + UIUtils.basicSparkPage(content, logType + " log page for " + appId) + } + + /** Determine the byte range for a log or log page. */ + def getByteRange(path: String, offset: Option[Long], byteLength: Int) + : (Long, Long) = { + val defaultBytes = 100 * 1024 + val maxBytes = 1024 * 1024 + + val file = new File(path) + val logLength = file.length() + val getOffset = offset.getOrElse(logLength-defaultBytes) + + val startByte = + if (getOffset < 0) 0L + else if (getOffset > logLength) logLength + else getOffset + + val logPageLength = math.min(byteLength, maxBytes) + + val endByte = math.min(startByte+logPageLength, logLength) + + (startByte, endByte) + } + + def stop() { + server.foreach(_.stop()) + } +} + +private[spark] object WorkerWebUI { + val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val DEFAULT_PORT="8081" +} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala new file mode 100644 index 0000000000..d365804994 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.io.{File} +import java.lang.management.ManagementFactory +import java.nio.ByteBuffer +import java.util.concurrent._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap + +import org.apache.spark.scheduler._ +import org.apache.spark._ +import org.apache.spark.util.Utils + + +/** + * The Mesos executor for Spark. 
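(A quick worked example of getByteRange above, assuming the default byteLength of 100 * 1024 bytes and no offset parameter.) For a 1,048,576-byte stderr file, getOffset = 1,048,576 - 102,400 = 946,176, logPageLength = min(102,400, 1,048,576) = 102,400, and endByte = min(946,176 + 102,400, 1,048,576) = 1,048,576, so the page serves the last 100 KB of the log by default; the Previous/Next buttons then walk that window backwards and forwards.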
+ */ +private[spark] class Executor( + executorId: String, + slaveHostname: String, + properties: Seq[(String, String)]) + extends Logging +{ + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. + private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + + private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + + initLogging() + + // No ip or host:port - just hostname + Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + // must not have port specified. + assert (0 == Utils.parseHostPort(slaveHostname)._2) + + // Make sure the local hostname we report matches the cluster scheduler's name for this host + Utils.setCustomHostname(slaveHostname) + + // Set spark.* system properties from executor arg + for ((key, value) <- properties) { + System.setProperty(key, value) + } + + // If we are in yarn mode, systems can have different disk layouts so we must set it + // to what Yarn on this system said was available. This will be used later when SparkEnv + // created. + if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) { + System.setProperty("spark.local.dir", getYarnLocalDirs()) + } + + // Create our ClassLoader and set it on this thread + private val urlClassLoader = createClassLoader() + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) + + // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. + Thread.setDefaultUncaughtExceptionHandler( + new Thread.UncaughtExceptionHandler { + override def uncaughtException(thread: Thread, exception: Throwable) { + try { + logError("Uncaught exception in thread " + thread, exception) + + // We may have been called from a shutdown hook. If so, we must not call System.exit(). + // (If we do, we will deadlock.) + if (!Utils.inShutdown()) { + if (exception.isInstanceOf[OutOfMemoryError]) { + System.exit(ExecutorExitCode.OOM) + } else { + System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + } + } + } catch { + case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) + case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) + } + } + } + ) + + val executorSource = new ExecutorSource(this) + + // Initialize Spark environment (using system properties read above) + val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) + SparkEnv.set(env) + env.metricsSystem.registerSource(executorSource) + + private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + + // Start worker thread pool + val threadPool = new ThreadPoolExecutor( + 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + + def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { + threadPool.execute(new TaskRunner(context, taskId, serializedTask)) + } + + /** Get the Yarn approved local directories. */ + private def getYarnLocalDirs(): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so lets check both. We assume one of the 2 is set. 
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) + .getOrElse(Option(System.getenv("LOCAL_DIRS")) + .getOrElse("")) + + if (localDirs.isEmpty()) { + throw new Exception("Yarn Local dirs can't be empty") + } + return localDirs + } + + class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) + extends Runnable { + + override def run() { + val startTime = System.currentTimeMillis() + SparkEnv.set(env) + Thread.currentThread.setContextClassLoader(replClassLoader) + val ser = SparkEnv.get.closureSerializer.newInstance() + logInfo("Running task ID " + taskId) + context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + var attemptedTask: Option[Task[Any]] = None + var taskStart: Long = 0 + def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum + val startGCTime = getTotalGCTime + + try { + SparkEnv.set(env) + Accumulators.clear() + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + updateDependencies(taskFiles, taskJars) + val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + attemptedTask = Some(task) + logInfo("Its epoch is " + task.epoch) + env.mapOutputTracker.updateEpoch(task.epoch) + taskStart = System.currentTimeMillis() + val value = task.run(taskId.toInt) + val taskFinish = System.currentTimeMillis() + for (m <- task.metrics) { + m.hostname = Utils.localHostName + m.executorDeserializeTime = (taskStart - startTime).toInt + m.executorRunTime = (taskFinish - taskStart).toInt + m.jvmGCTime = getTotalGCTime - startGCTime + } + //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c + // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could + // just change the relevants bytes in the byte buffer + val accumUpdates = Accumulators.values + val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null)) + val serializedResult = ser.serialize(result) + logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit) + if (serializedResult.limit >= (akkaFrameSize - 1024)) { + context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure())) + return + } + context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) + logInfo("Finished task ID " + taskId) + } catch { + case ffe: FetchFailedException => { + val reason = ffe.toTaskEndReason + context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + } + + case t: Throwable => { + val serviceTime = (System.currentTimeMillis() - taskStart).toInt + val metrics = attemptedTask.flatMap(t => t.metrics) + for (m <- metrics) { + m.executorRunTime = serviceTime + m.jvmGCTime = getTotalGCTime - startGCTime + } + val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) + context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + + // TODO: Should we exit the whole executor here? On the one hand, the failed task may + // have left some weird state around depending on when the exception was thrown, but on + // the other hand, maybe we could detect that when future tasks fail and exit then. 
+ logError("Exception in task ID " + taskId, t) + //System.exit(1) + } + } + } + } + + /** + * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes + * created by the interpreter to the search path + */ + private def createClassLoader(): ExecutorURLClassLoader = { + var loader = this.getClass.getClassLoader + + // For each of the jars in the jarSet, add them to the class loader. + // We assume each of the files has already been fetched. + val urls = currentJars.keySet.map { uri => + new File(uri.split("/").last).toURI.toURL + }.toArray + new ExecutorURLClassLoader(urls, loader) + } + + /** + * If the REPL is in use, add another ClassLoader that will read + * new classes defined by the REPL as the user types code + */ + private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { + val classUri = System.getProperty("spark.repl.class.uri") + if (classUri != null) { + logInfo("Using REPL class URI: " + classUri) + try { + val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + .asInstanceOf[Class[_ <: ClassLoader]] + val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) + return constructor.newInstance(classUri, parent) + } catch { + case _: ClassNotFoundException => + logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") + System.exit(1) + null + } + } else { + return parent + } + } + + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. + */ + private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala new file mode 100644 index 0000000000..ad7dd34c76 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.nio.ByteBuffer +import org.apache.spark.TaskState.TaskState + +/** + * A pluggable interface used by the Executor to send updates to the cluster scheduler. + */ +private[spark] trait ExecutorBackend { + def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala new file mode 100644 index 0000000000..e5c9bbbe28 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +/** + * These are exit codes that executors should use to provide the master with information about + * executor failures assuming that cluster management framework can capture the exit codes (but + * perhaps not log files). The exit code constants here are chosen to be unlikely to conflict + * with "natural" exit statuses that may be caused by the JVM or user code. In particular, + * exit codes 128+ arise on some Unix-likes as a result of signals, and it appears that the + * OpenJDK JVM may use exit code 1 in some of its own "last chance" code. + */ +private[spark] +object ExecutorExitCode { + /** The default uncaught exception handler was reached. */ + val UNCAUGHT_EXCEPTION = 50 + + /** The default uncaught exception handler was called and an exception was encountered while + logging the exception. */ + val UNCAUGHT_EXCEPTION_TWICE = 51 + + /** The default uncaught exception handler was reached, and the uncaught exception was an + OutOfMemoryError. */ + val OOM = 52 + + /** DiskStore failed to create a local temporary directory after many attempts. 
*/ + val DISK_STORE_FAILED_TO_CREATE_DIR = 53 + + def explainExitCode(exitCode: Int): String = { + exitCode match { + case UNCAUGHT_EXCEPTION => "Uncaught exception" + case UNCAUGHT_EXCEPTION_TWICE => "Uncaught exception, and logging the exception failed" + case OOM => "OutOfMemoryError" + case DISK_STORE_FAILED_TO_CREATE_DIR => + "Failed to create local directory (bad spark.local.dir?)" + case _ => + "Unknown executor exit code (" + exitCode + ")" + ( + if (exitCode > 128) + " (died from signal " + (exitCode - 128) + "?)" + else + "" + ) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala new file mode 100644 index 0000000000..bf8fb4fd21 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.fs.LocalFileSystem + +import scala.collection.JavaConversions._ + +import org.apache.spark.metrics.source.Source + +class ExecutorSource(val executor: Executor) extends Source { + private def fileStats(scheme: String) : Option[FileSystem.Statistics] = + FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption + + private def registerFileSystemStat[T]( + scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { + metricRegistry.register(MetricRegistry.name("filesystem", scheme, name), new Gauge[T] { + override def getValue: T = fileStats(scheme).map(f).getOrElse(defaultValue) + }) + } + + val metricRegistry = new MetricRegistry() + val sourceName = "executor" + + // Gauge for executor thread pool's actively executing task counts + metricRegistry.register(MetricRegistry.name("threadpool", "activeTask", "count"), new Gauge[Int] { + override def getValue: Int = executor.threadPool.getActiveCount() + }) + + // Gauge for executor thread pool's approximate total number of tasks that have been completed + metricRegistry.register(MetricRegistry.name("threadpool", "completeTask", "count"), new Gauge[Long] { + override def getValue: Long = executor.threadPool.getCompletedTaskCount() + }) + + // Gauge for executor thread pool's current number of threads + metricRegistry.register(MetricRegistry.name("threadpool", "currentPool", "size"), new Gauge[Int] { + override def getValue: Int = executor.threadPool.getPoolSize() + }) + + // Gauge got executor thread pool's largest number of threads that have ever simultaneously been in th pool + metricRegistry.register(MetricRegistry.name("threadpool", "maxPool", 
"size"), new Gauge[Int] { + override def getValue: Int = executor.threadPool.getMaximumPoolSize() + }) + + // Gauge for file system stats of this executor + for (scheme <- Array("hdfs", "file")) { + registerFileSystemStat(scheme, "bytesRead", _.getBytesRead(), 0L) + registerFileSystemStat(scheme, "bytesWritten", _.getBytesWritten(), 0L) + registerFileSystemStat(scheme, "readOps", _.getReadOps(), 0) + registerFileSystemStat(scheme, "largeReadOps", _.getLargeReadOps(), 0) + registerFileSystemStat(scheme, "writeOps", _.getWriteOps(), 0) + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala new file mode 100644 index 0000000000..f9bfe8ed2f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.net.{URLClassLoader, URL} + +/** + * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. + */ +private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends URLClassLoader(urls, parent) { + + override def addURL(url: URL) { + super.addURL(url) + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala new file mode 100644 index 0000000000..da62091980 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.executor + +import java.nio.ByteBuffer +import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} +import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} +import org.apache.spark.TaskState.TaskState +import com.google.protobuf.ByteString +import org.apache.spark.{Logging} +import org.apache.spark.TaskState +import org.apache.spark.util.Utils + +private[spark] class MesosExecutorBackend + extends MesosExecutor + with ExecutorBackend + with Logging { + + var executor: Executor = null + var driver: ExecutorDriver = null + + override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { + val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build() + driver.sendStatusUpdate(MesosTaskStatus.newBuilder() + .setTaskId(mesosTaskId) + .setState(TaskState.toMesos(state)) + .setData(ByteString.copyFrom(data)) + .build()) + } + + override def registered( + driver: ExecutorDriver, + executorInfo: ExecutorInfo, + frameworkInfo: FrameworkInfo, + slaveInfo: SlaveInfo) { + logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) + this.driver = driver + val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) + executor = new Executor( + executorInfo.getExecutorId.getValue, + slaveInfo.getHostname, + properties) + } + + override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { + val taskId = taskInfo.getTaskId.getValue.toLong + if (executor == null) { + logError("Received launchTask but executor was null") + } else { + executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer) + } + } + + override def error(d: ExecutorDriver, message: String) { + logError("Error from Mesos: " + message) + } + + override def killTask(d: ExecutorDriver, t: TaskID) { + logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") + } + + override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} + + override def disconnected(d: ExecutorDriver) {} + + override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {} + + override def shutdown(d: ExecutorDriver) {} +} + +/** + * Entry point for Mesos executor. + */ +private[spark] object MesosExecutorBackend { + def main(args: Array[String]) { + MesosNativeLibrary.load() + // Create a new Executor and start it running + val runner = new MesosExecutorBackend() + new MesosExecutorDriver(runner).run() + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala new file mode 100644 index 0000000000..7839023868 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
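(A hedged sketch, not part of this patch.) WorkerSource and ExecutorSource above both follow the same recipe: a Source with a sourceName and a MetricRegistry populated with Gauges. A custom source exposing, say, the depth of some application queue could look like the following, with the queue parameter being a hypothetical stand-in; it would be handed to MetricsSystem.registerSource, shown later in this change.

import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.metrics.source.Source

class QueueSource(queue: java.util.Queue[_]) extends Source {
  val sourceName = "queue"
  val metricRegistry = new MetricRegistry()

  // Report the current queue depth whenever a sink polls this gauge.
  metricRegistry.register(MetricRegistry.name("pending", "count"), new Gauge[Int] {
    override def getValue: Int = queue.size
  })
}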
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.nio.ByteBuffer + +import akka.actor.{ActorRef, Actor, Props, Terminated} +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.util.{Utils, AkkaUtils} + + +private[spark] class StandaloneExecutorBackend( + driverUrl: String, + executorId: String, + hostPort: String, + cores: Int) + extends Actor + with ExecutorBackend + with Logging { + + Utils.checkHostPort(hostPort, "Expected hostport") + + var executor: Executor = null + var driver: ActorRef = null + + override def preStart() { + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! RegisterExecutor(executorId, hostPort, cores) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(driver) // Doesn't work with remote actors, but useful for testing + } + + override def receive = { + case RegisteredExecutor(sparkProperties) => + logInfo("Successfully registered with driver") + // Make this host instead of hostPort ? + executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) + + case RegisterExecutorFailed(message) => + logError("Slave registration failed: " + message) + System.exit(1) + + case LaunchTask(taskDesc) => + logInfo("Got assigned task " + taskDesc.taskId) + if (executor == null) { + logError("Received launchTask but executor was null") + System.exit(1) + } else { + executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) + } + + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + logError("Driver terminated or disconnected! Shutting down.") + System.exit(1) + } + + override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { + driver ! 
StatusUpdate(executorId, taskId, state, data) + } +} + +private[spark] object StandaloneExecutorBackend { + def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + // Debug code + Utils.checkHost(hostname) + + // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor + // before getting started with all our system properties, etc + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + // set it + val sparkHostPort = hostname + ":" + boundPort + System.setProperty("spark.hostPort", sparkHostPort) + val actor = actorSystem.actorOf( + Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), + name = "Executor") + actorSystem.awaitTermination() + } + + def main(args: Array[String]) { + if (args.length < 4) { + //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors + System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]") + System.exit(1) + } + run(args(0), args(1), args(2), args(3).toInt) + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala new file mode 100644 index 0000000000..f311141148 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.executor + +class TaskMetrics extends Serializable { + /** + * Host's name the task runs on + */ + var hostname: String = _ + + /** + * Time taken on the executor to deserialize this task + */ + var executorDeserializeTime: Int = _ + + /** + * Time the executor spends actually running the task (including fetching shuffle data) + */ + var executorRunTime: Int = _ + + /** + * The number of bytes this task transmitted back to the driver as the TaskResult + */ + var resultSize: Long = _ + + /** + * Amount of time the JVM spent in garbage collection while executing this task + */ + var jvmGCTime: Long = _ + + /** + * If this task reads from shuffle output, metrics on getting shuffle data will be collected here + */ + var shuffleReadMetrics: Option[ShuffleReadMetrics] = None + + /** + * If this task writes to shuffle output, metrics on the written shuffle data will be collected here + */ + var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None +} + +object TaskMetrics { + private[spark] def empty(): TaskMetrics = new TaskMetrics +} + + +class ShuffleReadMetrics extends Serializable { + /** + * Time when shuffle finishs + */ + var shuffleFinishTime: Long = _ + + /** + * Total number of blocks fetched in a shuffle (remote or local) + */ + var totalBlocksFetched: Int = _ + + /** + * Number of remote blocks fetched in a shuffle + */ + var remoteBlocksFetched: Int = _ + + /** + * Local blocks fetched in a shuffle + */ + var localBlocksFetched: Int = _ + + /** + * Total time that is spent blocked waiting for shuffle to fetch data + */ + var fetchWaitTime: Long = _ + + /** + * The total amount of time for all the shuffle fetches. This adds up time from overlapping + * shuffles, so can be longer than task time + */ + var remoteFetchTime: Long = _ + + /** + * Total number of remote bytes read from a shuffle + */ + var remoteBytesRead: Long = _ +} + +class ShuffleWriteMetrics extends Serializable { + /** + * Number of bytes written for a shuffle + */ + var shuffleBytesWritten: Long = _ +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala new file mode 100644 index 0000000000..570a979b56 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io + +import java.io.{InputStream, OutputStream} + +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} + +import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} + + +/** + * CompressionCodec allows the customization of choosing different compression implementations + * to be used in block storage. 
+ */ +trait CompressionCodec { + + def compressedOutputStream(s: OutputStream): OutputStream + + def compressedInputStream(s: InputStream): InputStream +} + + +private[spark] object CompressionCodec { + + def createCodec(): CompressionCodec = { + createCodec(System.getProperty( + "spark.io.compression.codec", classOf[LZFCompressionCodec].getName)) + } + + def createCodec(codecName: String): CompressionCodec = { + Class.forName(codecName, true, Thread.currentThread.getContextClassLoader) + .newInstance().asInstanceOf[CompressionCodec] + } +} + + +/** + * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. + */ +class LZFCompressionCodec extends CompressionCodec { + + override def compressedOutputStream(s: OutputStream): OutputStream = { + new LZFOutputStream(s).setFinishBlockOnFlush(true) + } + + override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s) +} + + +/** + * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. + * Block size can be configured by spark.io.compression.snappy.block.size. + */ +class SnappyCompressionCodec extends CompressionCodec { + + override def compressedOutputStream(s: OutputStream): OutputStream = { + val blockSize = System.getProperty("spark.io.compression.snappy.block.size", "32768").toInt + new SnappyOutputStream(s, blockSize) + } + + override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala new file mode 100644 index 0000000000..0f9c4e00b1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
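(A hedged usage sketch, not part of the patch; the property name and class names come from the code above, the rest is illustrative.) Selecting Snappy and round-tripping a small buffer through the codec would look roughly like:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.io.{CompressionCodec, SnappyCompressionCodec}

System.setProperty("spark.io.compression.codec", classOf[SnappyCompressionCodec].getName)
val codec = CompressionCodec.createCodec()

// Compress a few bytes, then read them back through the matching input stream.
val buffer = new ByteArrayOutputStream()
val out = codec.compressedOutputStream(buffer)
out.write("hello spark".getBytes("UTF-8"))
out.close()
val in = codec.compressedInputStream(new ByteArrayInputStream(buffer.toByteArray))
val roundTripped = scala.io.Source.fromInputStream(in, "UTF-8").mkString   // "hello spark"

Leaving the property unset falls back to LZFCompressionCodec, per createCodec().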
+ */ + +package org.apache.spark.metrics + +import java.util.Properties +import java.io.{File, FileInputStream, InputStream, IOException} + +import scala.collection.mutable +import scala.util.matching.Regex + +import org.apache.spark.Logging + +private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { + initLogging() + + val DEFAULT_PREFIX = "*" + val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r + val METRICS_CONF = "metrics.properties" + + val properties = new Properties() + var propertyCategories: mutable.HashMap[String, Properties] = null + + private def setDefaultProperties(prop: Properties) { + prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet") + prop.setProperty("*.sink.servlet.uri", "/metrics/json") + prop.setProperty("*.sink.servlet.sample", "false") + prop.setProperty("master.sink.servlet.uri", "/metrics/master/json") + prop.setProperty("applications.sink.servlet.uri", "/metrics/applications/json") + } + + def initialize() { + //Add default properties in case there's no properties file + setDefaultProperties(properties) + + // If spark.metrics.conf is not set, try to get file in class path + var is: InputStream = null + try { + is = configFile match { + case Some(f) => new FileInputStream(f) + case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF) + } + + if (is != null) { + properties.load(is) + } + } catch { + case e: Exception => logError("Error loading configure file", e) + } finally { + if (is != null) is.close() + } + + propertyCategories = subProperties(properties, INSTANCE_REGEX) + if (propertyCategories.contains(DEFAULT_PREFIX)) { + import scala.collection.JavaConversions._ + + val defaultProperty = propertyCategories(DEFAULT_PREFIX) + for { (inst, prop) <- propertyCategories + if (inst != DEFAULT_PREFIX) + (k, v) <- defaultProperty + if (prop.getProperty(k) == null) } { + prop.setProperty(k, v) + } + } + } + + def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { + val subProperties = new mutable.HashMap[String, Properties] + import scala.collection.JavaConversions._ + prop.foreach { kv => + if (regex.findPrefixOf(kv._1) != None) { + val regex(prefix, suffix) = kv._1 + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) + } + } + subProperties + } + + def getInstance(inst: String): Properties = { + propertyCategories.get(inst) match { + case Some(s) => s + case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala new file mode 100644 index 0000000000..bec0c83be8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics + +import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.metrics.sink.{MetricsServlet, Sink} +import org.apache.spark.metrics.source.Source + +/** + * Spark Metrics System, created by specific "instance", combined by source, + * sink, periodically poll source metrics data to sink destinations. + * + * "instance" specify "who" (the role) use metrics system. In spark there are several roles + * like master, worker, executor, client driver, these roles will create metrics system + * for monitoring. So instance represents these roles. Currently in Spark, several instances + * have already implemented: master, worker, executor, driver, applications. + * + * "source" specify "where" (source) to collect metrics data. In metrics system, there exists + * two kinds of source: + * 1. Spark internal source, like MasterSource, WorkerSource, etc, which will collect + * Spark component's internal state, these sources are related to instance and will be + * added after specific metrics system is created. + * 2. Common source, like JvmSource, which will collect low level state, is configured by + * configuration and loaded through reflection. + * + * "sink" specify "where" (destination) to output metrics data to. Several sinks can be + * coexisted and flush metrics to all these sinks. + * + * Metrics configuration format is like below: + * [instance].[sink|source].[name].[options] = xxxx + * + * [instance] can be "master", "worker", "executor", "driver", "applications" which means only + * the specified instance has this property. + * wild card "*" can be used to replace instance name, which means all the instances will have + * this property. + * + * [sink|source] means this property belongs to source or sink. This field can only be source or sink. + * + * [name] specify the name of sink or source, it is custom defined. + * + * [options] is the specific property of this source or sink. + */ +private[spark] class MetricsSystem private (val instance: String) extends Logging { + initLogging() + + val confFile = System.getProperty("spark.metrics.conf") + val metricsConfig = new MetricsConfig(Option(confFile)) + + val sinks = new mutable.ArrayBuffer[Sink] + val sources = new mutable.ArrayBuffer[Source] + val registry = new MetricRegistry() + + // Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui + private var metricsServlet: Option[MetricsServlet] = None + + /** Get any UI handlers used by this metrics system. 
*/ + def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array()) + + metricsConfig.initialize() + registerSources() + registerSinks() + + def start() { + sinks.foreach(_.start) + } + + def stop() { + sinks.foreach(_.stop) + } + + def registerSource(source: Source) { + sources += source + try { + registry.register(source.sourceName, source.metricRegistry) + } catch { + case e: IllegalArgumentException => logInfo("Metrics already registered", e) + } + } + + def removeSource(source: Source) { + sources -= source + registry.removeMatching(new MetricFilter { + def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName) + }) + } + + def registerSources() { + val instConfig = metricsConfig.getInstance(instance) + val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX) + + // Register all the sources related to instance + sourceConfigs.foreach { kv => + val classPath = kv._2.getProperty("class") + try { + val source = Class.forName(classPath).newInstance() + registerSource(source.asInstanceOf[Source]) + } catch { + case e: Exception => logError("Source class " + classPath + " cannot be instantialized", e) + } + } + } + + def registerSinks() { + val instConfig = metricsConfig.getInstance(instance) + val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX) + + sinkConfigs.foreach { kv => + val classPath = kv._2.getProperty("class") + try { + val sink = Class.forName(classPath) + .getConstructor(classOf[Properties], classOf[MetricRegistry]) + .newInstance(kv._2, registry) + if (kv._1 == "servlet") { + metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) + } else { + sinks += sink.asInstanceOf[Sink] + } + } catch { + case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e) + } + } + } +} + +private[spark] object MetricsSystem { + val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r + val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r + + val MINIMAL_POLL_UNIT = TimeUnit.SECONDS + val MINIMAL_POLL_PERIOD = 1 + + def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) { + val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit) + if (period < MINIMAL_POLL_PERIOD) { + throw new IllegalArgumentException("Polling period " + pollPeriod + " " + pollUnit + + " below than minimal polling period ") + } + } + + def createMetricsSystem(instance: String): MetricsSystem = new MetricsSystem(instance) +} diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala new file mode 100644 index 0000000000..bce257d6e6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import com.codahale.metrics.{ConsoleReporter, MetricRegistry} + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import org.apache.spark.metrics.MetricsSystem + +class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val CONSOLE_DEFAULT_PERIOD = 10 + val CONSOLE_DEFAULT_UNIT = "SECONDS" + + val CONSOLE_KEY_PERIOD = "period" + val CONSOLE_KEY_UNIT = "unit" + + val pollPeriod = Option(property.getProperty(CONSOLE_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => CONSOLE_DEFAULT_PERIOD + } + + val pollUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) + } + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val reporter: ConsoleReporter = ConsoleReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .build() + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } +} + diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala new file mode 100644 index 0000000000..3d1a06a395 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.metrics.sink + +import com.codahale.metrics.{CsvReporter, MetricRegistry} + +import java.io.File +import java.util.{Locale, Properties} +import java.util.concurrent.TimeUnit + +import org.apache.spark.metrics.MetricsSystem + +class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val CSV_KEY_PERIOD = "period" + val CSV_KEY_UNIT = "unit" + val CSV_KEY_DIR = "directory" + + val CSV_DEFAULT_PERIOD = 10 + val CSV_DEFAULT_UNIT = "SECONDS" + val CSV_DEFAULT_DIR = "/tmp/" + + val pollPeriod = Option(property.getProperty(CSV_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => CSV_DEFAULT_PERIOD + } + + val pollUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) + } + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match { + case Some(s) => s + case None => CSV_DEFAULT_DIR + } + + val reporter: CsvReporter = CsvReporter.forRegistry(registry) + .formatFor(Locale.US) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .build(new File(pollDir)) + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } +} + diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala new file mode 100644 index 0000000000..621d086d41 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import com.codahale.metrics.{JmxReporter, MetricRegistry} + +import java.util.Properties + +class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() + + override def start() { + reporter.start() + } + + override def stop() { + reporter.stop() + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala new file mode 100644 index 0000000000..4e90dd4323 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
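For illustration only (not part of this change), the CsvSink defined above would typically be enabled with entries like the following; the period, unit and directory values are assumptions rather than defaults taken from this code (the in-code defaults are 10 SECONDS and "/tmp/"):

    *.sink.csv.class=org.apache.spark.metrics.sink.CsvSink
    *.sink.csv.period=1
    *.sink.csv.unit=minutes
    *.sink.csv.directory=/var/log/spark-metrics/

Any period that converts to less than one second (for example period=500 with unit=milliseconds) is rejected up front by MetricsSystem.checkMinimalPollingPeriod.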
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.json.MetricsModule + +import com.fasterxml.jackson.databind.ObjectMapper + +import java.util.Properties +import java.util.concurrent.TimeUnit +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.ui.JettyUtils + +class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink { + val SERVLET_KEY_URI = "uri" + val SERVLET_KEY_SAMPLE = "sample" + + val servletURI = property.getProperty(SERVLET_KEY_URI) + + val servletShowSample = property.getProperty(SERVLET_KEY_SAMPLE).toBoolean + + val mapper = new ObjectMapper().registerModule( + new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) + + def getHandlers = Array[(String, Handler)]( + (servletURI, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) + ) + + def getMetricsSnapshot(request: HttpServletRequest): String = { + mapper.writeValueAsString(registry) + } + + override def start() { } + + override def stop() { } +} diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala new file mode 100644 index 0000000000..3a739aa563 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +trait Sink { + def start: Unit + def stop: Unit +} diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala new file mode 100644 index 0000000000..75cb2b8973 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
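A minimal sketch, under stated assumptions, of the contract that MetricsSystem.registerSinks() relies on: a configurable sink only needs a (Properties, MetricRegistry) constructor and the Sink trait above. The class name NoOpSink and the "noop" key are invented for this illustration and are not part of the change:

    import java.util.Properties

    import com.codahale.metrics.MetricRegistry

    import org.apache.spark.metrics.sink.Sink

    // Satisfies the reflective contract but publishes nothing.
    class NoOpSink(val property: Properties, val registry: MetricRegistry) extends Sink {
      override def start() { }  // nothing to schedule
      override def stop() { }   // nothing to tear down
    }

It would be enabled with an entry such as *.sink.noop.class=NoOpSink, assuming the class is on the classpath under that name.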
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} + +class JvmSource extends Source { + val sourceName = "jvm" + val metricRegistry = new MetricRegistry() + + val gcMetricSet = new GarbageCollectorMetricSet + val memGaugeSet = new MemoryUsageGaugeSet + + metricRegistry.registerAll(gcMetricSet) + metricRegistry.registerAll(memGaugeSet) +} diff --git a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala new file mode 100644 index 0000000000..3fee55cc6d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.MetricRegistry + +trait Source { + def sourceName: String + def metricRegistry: MetricRegistry +} diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala new file mode 100644 index 0000000000..f736bb3713 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
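A hedged sketch of a user-defined Source (not part of this change): registerSources() only needs a no-argument constructor, and metrics placed in the source's own MetricRegistry end up prefixed by sourceName once the source is added to a MetricsSystem. UptimeSource and its gauge name are invented for illustration:

    import com.codahale.metrics.{Gauge, MetricRegistry}

    import org.apache.spark.metrics.source.Source

    class UptimeSource extends Source {
      val sourceName = "uptime"
      val metricRegistry = new MetricRegistry()

      private val startedAt = System.currentTimeMillis

      // Reports how long ago this source was constructed, in milliseconds.
      metricRegistry.register(MetricRegistry.name("millis"), new Gauge[Long] {
        override def getValue: Long = System.currentTimeMillis - startedAt
      })
    }

It could be registered programmatically with metricsSystem.registerSource(new UptimeSource) or declared as *.source.uptime.class=UptimeSource in the metrics configuration.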
+ */ + +package org.apache.spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.storage.BlockManager + + +private[spark] +class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) + extends Message(Message.BUFFER_MESSAGE, id_) { + + val initialSize = currentSize() + var gotChunkForSendingOnce = false + + def size = initialSize + + def currentSize() = { + if (buffers == null || buffers.isEmpty) { + 0 + } else { + buffers.map(_.remaining).reduceLeft(_ + _) + } + } + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { + if (maxChunkSize <= 0) { + throw new Exception("Max chunk size is " + maxChunkSize) + } + + if (size == 0 && gotChunkForSendingOnce == false) { + val newChunk = new MessageChunk( + new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + gotChunkForSendingOnce = true + return Some(newChunk) + } + + while(!buffers.isEmpty) { + val buffer = buffers(0) + if (buffer.remaining == 0) { + BlockManager.dispose(buffer) + buffers -= buffer + } else { + val newBuffer = if (buffer.remaining <= maxChunkSize) { + buffer.duplicate() + } else { + buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] + } + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + gotChunkForSendingOnce = true + return Some(newChunk) + } + } + None + } + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { + // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer + if (buffers.size > 1) { + throw new Exception("Attempting to get chunk from message with multiple data buffers") + } + val buffer = buffers(0) + if (buffer.remaining > 0) { + if (buffer.remaining < chunkSize) { + throw new Exception("Not enough space in data buffer for receiving chunk") + } + val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + return Some(newChunk) + } + None + } + + def flip() { + buffers.foreach(_.flip) + } + + def hasAckId() = (ackId != 0) + + def isCompletelyReceived() = !buffers(0).hasRemaining + + override def toString = { + if (hasAckId) { + "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" + } else { + "BufferMessage(id = " + id + ", size = " + size + ")" + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala new file mode 100644 index 0000000000..95cb0206ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import org.apache.spark._ + +import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} + +import java.io._ +import java.nio._ +import java.nio.channels._ +import java.nio.channels.spi._ +import java.net._ + + +private[spark] +abstract class Connection(val channel: SocketChannel, val selector: Selector, + val socketRemoteConnectionManagerId: ConnectionManagerId) + extends Logging { + + def this(channel_ : SocketChannel, selector_ : Selector) = { + this(channel_, selector_, + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) + } + + channel.configureBlocking(false) + channel.socket.setTcpNoDelay(true) + channel.socket.setReuseAddress(true) + channel.socket.setKeepAlive(true) + /*channel.socket.setReceiveBufferSize(32768) */ + + @volatile private var closed = false + var onCloseCallback: Connection => Unit = null + var onExceptionCallback: (Connection, Exception) => Unit = null + var onKeyInterestChangeCallback: (Connection, Int) => Unit = null + + val remoteAddress = getRemoteAddress() + + def resetForceReregister(): Boolean + + // Read channels typically do not register for write and write does not for read + // Now, we do have write registering for read too (temporarily), but this is to detect + // channel close NOT to actually read/consume data on it ! + // How does this work if/when we move to SSL ? + + // What is the interest to register with selector for when we want this connection to be selected + def registerInterest() + + // What is the interest to register with selector for when we want this connection to + // be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, + // it will be SelectionKey.OP_READ (until we fix it properly) + def unregisterInterest() + + // On receiving a read event, should we change the interest for this channel or not ? + // Will be true for ReceivingConnection, false for SendingConnection. + def changeInterestForRead(): Boolean + + // On receiving a write event, should we change the interest for this channel or not ? + // Will be false for ReceivingConnection, true for SendingConnection. + // Actually, for now, should not get triggered for ReceivingConnection + def changeInterestForWrite(): Boolean + + def getRemoteConnectionManagerId(): ConnectionManagerId = { + socketRemoteConnectionManagerId + } + + def key() = channel.keyFor(selector) + + def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + + // Returns whether we have to register for further reads or not. + def read(): Boolean = { + throw new UnsupportedOperationException( + "Cannot read on connection of type " + this.getClass.toString) + } + + // Returns whether we have to register for further writes or not. 
+ def write(): Boolean = { + throw new UnsupportedOperationException( + "Cannot write on connection of type " + this.getClass.toString) + } + + def close() { + closed = true + val k = key() + if (k != null) { + k.cancel() + } + channel.close() + callOnCloseCallback() + } + + protected def isClosed: Boolean = closed + + def onClose(callback: Connection => Unit) { + onCloseCallback = callback + } + + def onException(callback: (Connection, Exception) => Unit) { + onExceptionCallback = callback + } + + def onKeyInterestChange(callback: (Connection, Int) => Unit) { + onKeyInterestChangeCallback = callback + } + + def callOnExceptionCallback(e: Exception) { + if (onExceptionCallback != null) { + onExceptionCallback(this, e) + } else { + logError("Error in connection to " + getRemoteConnectionManagerId() + + " and OnExceptionCallback not registered", e) + } + } + + def callOnCloseCallback() { + if (onCloseCallback != null) { + onCloseCallback(this) + } else { + logWarning("Connection to " + getRemoteConnectionManagerId() + + " closed and OnExceptionCallback not registered") + } + + } + + def changeConnectionKeyInterest(ops: Int) { + if (onKeyInterestChangeCallback != null) { + onKeyInterestChangeCallback(this, ops) + } else { + throw new Exception("OnKeyInterestChangeCallback not registered") + } + } + + def printRemainingBuffer(buffer: ByteBuffer) { + val bytes = new Array[Byte](buffer.remaining) + val curPosition = buffer.position + buffer.get(bytes) + bytes.foreach(x => print(x + " ")) + buffer.position(curPosition) + print(" (" + bytes.size + ")") + } + + def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { + val bytes = new Array[Byte](length) + val curPosition = buffer.position + buffer.position(position) + buffer.get(bytes) + bytes.foreach(x => print(x + " ")) + print(" (" + position + ", " + length + ")") + buffer.position(curPosition) + } +} + + +private[spark] +class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) + extends Connection(SocketChannel.open, selector_, remoteId_) { + + private class Outbox(fair: Int = 0) { + val messages = new Queue[Message]() + val defaultChunkSize = 65536 //32768 //16384 + var nextMessageToBeUsed = 0 + + def addMessage(message: Message) { + messages.synchronized{ + /*messages += message*/ + messages.enqueue(message) + logDebug("Added [" + message + "] to outbox for sending to " + + "[" + getRemoteConnectionManagerId() + "]") + } + } + + def getChunk(): Option[MessageChunk] = { + fair match { + case 0 => getChunkFIFO() + case 1 => getChunkRR() + case _ => throw new Exception("Unexpected fairness policy in outbox") + } + } + + private def getChunkFIFO(): Option[MessageChunk] = { + /*logInfo("Using FIFO")*/ + messages.synchronized { + while (!messages.isEmpty) { + val message = messages(0) + val chunk = message.getChunkForSending(defaultChunkSize) + if (chunk.isDefined) { + messages += message // this is probably incorrect, it wont work as fifo + if (!message.started) { + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } + return chunk + } else { + /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + + "] in " + message.timeTaken ) + } + } + } + None + } + + private def getChunkRR(): Option[MessageChunk] = { + messages.synchronized { + while 
(!messages.isEmpty) { + /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ + /*val message = messages(nextMessageToBeUsed)*/ + val message = messages.dequeue + val chunk = message.getChunkForSending(defaultChunkSize) + if (chunk.isDefined) { + messages.enqueue(message) + nextMessageToBeUsed = nextMessageToBeUsed + 1 + if (!message.started) { + logDebug( + "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") + message.started = true + message.startTime = System.currentTimeMillis + } + logTrace( + "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") + return chunk + } else { + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + + "] in " + message.timeTaken ) + } + } + } + None + } + } + + // outbox is used as a lock - ensure that it is always used as a leaf (since methods which + // lock it are invoked in context of other locks) + private val outbox = new Outbox(1) + /* + This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly + different purpose. This flag is to see if we need to force reregister for write even when we + do not have any pending bytes to write to socket. + This can happen due to a race between adding pending buffers, and checking for existing of + data as detailed in https://github.com/mesos/spark/pull/791 + */ + private var needForceReregister = false + val currentBuffers = new ArrayBuffer[ByteBuffer]() + + /*channel.socket.setSendBufferSize(256 * 1024)*/ + + override def getRemoteAddress() = address + + val DEFAULT_INTEREST = SelectionKey.OP_READ + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(DEFAULT_INTEREST) + } + + def send(message: Message) { + outbox.synchronized { + outbox.addMessage(message) + needForceReregister = true + } + if (channel.isConnected) { + registerInterest() + } + } + + // return previous value after resetting it. + def resetForceReregister(): Boolean = { + outbox.synchronized { + val result = needForceReregister + needForceReregister = false + result + } + } + + // MUST be called within the selector loop + def connect() { + try{ + channel.register(selector, SelectionKey.OP_CONNECT) + channel.connect(address) + logInfo("Initiating connection to [" + address + "]") + } catch { + case e: Exception => { + logError("Error connecting to " + address, e) + callOnExceptionCallback(e) + } + } + } + + def finishConnect(force: Boolean): Boolean = { + try { + // Typically, this should finish immediately since it was triggered by a connect + // selection - though need not necessarily always complete successfully. 
+ val connected = channel.finishConnect + if (!force && !connected) { + logInfo( + "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + return false + } + + // Fallback to previous behavior - assume finishConnect completed + // This will happen only when finishConnect failed for some repeated number of times + // (10 or so) + // Is highly unlikely unless there was an unclean close of socket, etc + registerInterest() + logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") + return true + } catch { + case e: Exception => { + logWarning("Error finishing connection to " + address, e) + callOnExceptionCallback(e) + // ignore + return true + } + } + } + + override def write(): Boolean = { + try { + while (true) { + if (currentBuffers.size == 0) { + outbox.synchronized { + outbox.getChunk() match { + case Some(chunk) => { + val buffers = chunk.buffers + // If we have 'seen' pending messages, then reset flag - since we handle that as normal + // registering of event (below) + if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() + currentBuffers ++= buffers + } + case None => { + // changeConnectionKeyInterest(0) + /*key.interestOps(0)*/ + return false + } + } + } + } + + if (currentBuffers.size > 0) { + val buffer = currentBuffers(0) + val remainingBytes = buffer.remaining + val writtenBytes = channel.write(buffer) + if (buffer.remaining == 0) { + currentBuffers -= buffer + } + if (writtenBytes < remainingBytes) { + // re-register for write. + return true + } + } + } + } catch { + case e: Exception => { + logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) + callOnExceptionCallback(e) + close() + return false + } + } + // should not happen - to keep scala compiler happy + return true + } + + // This is a hack to determine if remote socket was closed or not. + // SendingConnection DOES NOT expect to receive any data - if it does, it is an error + // For a bunch of cases, read will return -1 in case remote socket is closed : hence we + // register for reads to determine that. + override def read(): Boolean = { + // We don't expect the other side to send anything; so, we just read to detect an error or EOF. + try { + val length = channel.read(ByteBuffer.allocate(1)) + if (length == -1) { // EOF + close() + } else if (length > 0) { + logWarning( + "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) + } + } catch { + case e: Exception => + logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) + callOnExceptionCallback(e) + close() + } + + false + } + + override def changeInterestForRead(): Boolean = false + + override def changeInterestForWrite(): Boolean = ! 
isClosed +} + + +// Must be created within selector loop - else deadlock +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) + extends Connection(channel_, selector_) { + + class Inbox() { + val messages = new HashMap[Int, BufferMessage]() + + def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { + + def createNewMessage: BufferMessage = { + val newMessage = Message.create(header).asInstanceOf[BufferMessage] + newMessage.started = true + newMessage.startTime = System.currentTimeMillis + logDebug( + "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") + messages += ((newMessage.id, newMessage)) + newMessage + } + + val message = messages.getOrElseUpdate(header.id, createNewMessage) + logTrace( + "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") + message.getChunkForReceiving(header.chunkSize) + } + + def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { + messages.get(chunk.header.id) + } + + def removeMessage(message: Message) { + messages -= message.id + } + } + + @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + + override def getRemoteConnectionManagerId(): ConnectionManagerId = { + val currId = inferredRemoteManagerId + if (currId != null) currId else super.getRemoteConnectionManagerId() + } + + // The reciever's remote address is the local socket on remote side : which is NOT + // the connection manager id of the receiver. + // We infer that from the messages we receive on the receiver socket. + private def processConnectionManagerId(header: MessageChunkHeader) { + val currId = inferredRemoteManagerId + if (header.address == null || currId != null) return + + val managerId = ConnectionManagerId.fromSocketAddress(header.address) + + if (managerId != null) { + inferredRemoteManagerId = managerId + } + } + + + val inbox = new Inbox() + val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) + var onReceiveCallback: (Connection , Message) => Unit = null + var currentChunk: MessageChunk = null + + channel.register(selector, SelectionKey.OP_READ) + + override def read(): Boolean = { + try { + while (true) { + if (currentChunk == null) { + val headerBytesRead = channel.read(headerBuffer) + if (headerBytesRead == -1) { + close() + return false + } + if (headerBuffer.remaining > 0) { + // re-register for read event ... + return true + } + headerBuffer.flip + if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { + throw new Exception( + "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") + } + val header = MessageChunkHeader.create(headerBuffer) + headerBuffer.clear() + + processConnectionManagerId(header) + + header.typ match { + case Message.BUFFER_MESSAGE => { + if (header.totalSize == 0) { + if (onReceiveCallback != null) { + onReceiveCallback(this, Message.create(header)) + } + currentChunk = null + // re-register for read event ... + return true + } else { + currentChunk = inbox.getChunk(header).orNull + } + } + case _ => throw new Exception("Message of unknown type received") + } + } + + if (currentChunk == null) throw new Exception("No message chunk to receive data") + + val bytesRead = channel.read(currentChunk.buffer) + if (bytesRead == 0) { + // re-register for read event ... 
+ return true + } else if (bytesRead == -1) { + close() + return false + } + + /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ + + if (currentChunk.buffer.remaining == 0) { + /*println("Filled buffer at " + System.currentTimeMillis)*/ + val bufferMessage = inbox.getMessageForChunk(currentChunk).get + if (bufferMessage.isCompletelyReceived) { + bufferMessage.flip + bufferMessage.finishTime = System.currentTimeMillis + logDebug("Finished receiving [" + bufferMessage + "] from " + + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) + if (onReceiveCallback != null) { + onReceiveCallback(this, bufferMessage) + } + inbox.removeMessage(bufferMessage) + } + currentChunk = null + } + } + } catch { + case e: Exception => { + logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) + callOnExceptionCallback(e) + close() + return false + } + } + // should not happen - to keep scala compiler happy + return true + } + + def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} + + // override def changeInterestForRead(): Boolean = ! isClosed + override def changeInterestForRead(): Boolean = true + + override def changeInterestForWrite(): Boolean = { + throw new IllegalStateException("Unexpected invocation right now") + } + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_READ) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(0) + } + + // For read conn, always false. + override def resetForceReregister(): Boolean = false +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala new file mode 100644 index 0000000000..c24fd48c04 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -0,0 +1,721 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network + +import org.apache.spark._ + +import java.nio._ +import java.nio.channels._ +import java.nio.channels.spi._ +import java.net._ +import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} + +import scala.collection.mutable.HashSet +import scala.collection.mutable.HashMap +import scala.collection.mutable.SynchronizedMap +import scala.collection.mutable.SynchronizedQueue +import scala.collection.mutable.ArrayBuffer + +import scala.concurrent.{Await, Promise, ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.duration._ + +import org.apache.spark.util.Utils + +private[spark] class ConnectionManager(port: Int) extends Logging { + + class MessageStatus( + val message: Message, + val connectionManagerId: ConnectionManagerId, + completionHandler: MessageStatus => Unit) { + + var ackMessage: Option[Message] = None + var attempted = false + var acked = false + + def markDone() { completionHandler(this) } + } + + private val selector = SelectorProvider.provider.openSelector() + + private val handleMessageExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.handler.threads.min","20").toInt, + System.getProperty("spark.core.connection.handler.threads.max","60").toInt, + System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val handleReadWriteExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.io.threads.min","4").toInt, + System.getProperty("spark.core.connection.io.threads.max","32").toInt, + System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap + private val handleConnectExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.connect.threads.min","1").toInt, + System.getProperty("spark.core.connection.connect.threads.max","8").toInt, + System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val serverChannel = ServerSocketChannel.open() + private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] + private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + private val messageStatuses = new HashMap[Int, MessageStatus] + private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + private val registerRequests = new SynchronizedQueue[SendingConnection] + + implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + + private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + + serverChannel.configureBlocking(false) + serverChannel.socket.setReuseAddress(true) + serverChannel.socket.setReceiveBufferSize(256 * 1024) + + serverChannel.socket.bind(new InetSocketAddress(port)) + serverChannel.register(selector, SelectionKey.OP_ACCEPT) + + val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) + logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) + + private val selectorThread = new Thread("connection-manager-thread") { + 
override def run() = ConnectionManager.this.run() + } + selectorThread.setDaemon(true) + selectorThread.start() + + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerWrite(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + writeRunnableStarted.synchronized { + // So that we do not trigger more write events while processing this one. + // The write method will re-register when done. + if (conn.changeInterestForWrite()) conn.unregisterInterest() + if (writeRunnableStarted.contains(key)) { + // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) + return + } + + writeRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } + } + } + } + } ) + } + + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerRead(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + readRunnableStarted.synchronized { + // So that we do not trigger more read events while processing this one. + // The read method will re-register when done. + if (conn.changeInterestForRead())conn.unregisterInterest() + if (readRunnableStarted.contains(key)) { + return + } + + readRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } + } + } + } + } ) + } + + private def triggerConnect(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] + if (conn == null) return + + // prevent other events from being triggered + // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite + conn.changeConnectionKeyInterest(0) + + handleConnectExecutor.execute(new Runnable { + override def run() { + + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } + + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need not + // succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } + } ) + } + + // MUST be called within selector loop - else deadlock. 
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) { + try { + key.interestOps(0) + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + // Pushing to connect threadpool + handleConnectExecutor.execute(new Runnable { + override def run() { + try { + conn.callOnExceptionCallback(e) + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + try { + conn.close() + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + } + }) + } + + + def run() { + try { + while(!selectorThread.isInterrupted) { + while (! registerRequests.isEmpty) { + val conn: SendingConnection = registerRequests.dequeue + addListeners(conn) + conn.connect() + addConnection(conn) + } + + while(!keyInterestChangeRequests.isEmpty) { + val (key, ops) = keyInterestChangeRequests.dequeue + + try { + if (key.isValid) { + val connection = connectionsByKey.getOrElse(key, null) + if (connection != null) { + val lastOps = key.interestOps() + key.interestOps(ops) + + // hot loop - prevent materialization of string if trace not enabled. + if (isTraceEnabled()) { + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) + } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } + } + } + + val selectedKeysCount = + try { + selector.select() + } catch { + // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently. + case e: CancelledKeyException => { + // Some keys within the selectors list are invalid/closed. clear them. + val allKeys = selector.keys().iterator() + + while (allKeys.hasNext()) { + val key = allKeys.next() + try { + if (! key.isValid) { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? 
" + key, e) + triggerForceCloseByException(key, e) + } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } + } + } + } + 0 + } + + if (selectedKeysCount == 0) { + logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") + } + if (selectorThread.isInterrupted) { + logInfo("Selector thread was interrupted!") + return + } + + if (0 != selectedKeysCount) { + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext()) { + val key = selectedKeys.next + selectedKeys.remove() + try { + if (key.isValid) { + if (key.isAcceptable) { + acceptConnection(key) + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) + } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) + } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } + } + } + } + } + } catch { + case e: Exception => logError("Error in select loop", e) + } + } + + def acceptConnection(key: SelectionKey) { + val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] + + var newChannel = serverChannel.accept() + + // accept them all in a tight loop. non blocking accept with no processing, should be fine + while (newChannel != null) { + try { + val newConnection = new ReceivingConnection(newChannel, selector) + newConnection.onReceive(receiveMessage) + addListeners(newConnection) + addConnection(newConnection) + logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + } catch { + // might happen in case of issues with registering with selector + case e: Exception => logError("Error in accept loop", e) + } + + newChannel = serverChannel.accept() + } + } + + private def addListeners(connection: Connection) { + connection.onKeyInterestChange(changeConnectionKeyInterest) + connection.onException(handleConnectionError) + connection.onClose(removeConnection) + } + + def addConnection(connection: Connection) { + connectionsByKey += ((connection.key, connection)) + } + + def removeConnection(connection: Connection) { + connectionsByKey -= connection.key + + try { + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } + } else if (connection.isInstanceOf[ReceivingConnection]) { + val receivingConnection = connection.asInstanceOf[ReceivingConnection] + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + 
remoteConnectionManagerId) + + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (! sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } + + val sendingConnection = sendingConnectionOpt.get + connectionsById -= remoteConnectionManagerId + sendingConnection.close() + + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + + assert (sendingConnectionManagerId == remoteConnectionManagerId) + + messageStatuses.synchronized { + for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } + } + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } + } + } finally { + // So that the selection keys can be removed. + wakeupSelector() + } + } + + def handleConnectionError(connection: Connection, e: Exception) { + logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) + removeConnection(connection) + } + + def changeConnectionKeyInterest(connection: Connection, ops: Int) { + keyInterestChangeRequests += ((connection.key, ops)) + // so that registerations happen ! + wakeupSelector() + } + + def receiveMessage(connection: Connection, message: Message) { + val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) + logDebug("Received [" + message + "] from [" + connectionManagerId + "]") + val runnable = new Runnable() { + val creationTime = System.currentTimeMillis + def run() { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } + } + handleMessageExecutor.execute(runnable) + /*handleMessage(connection, message)*/ + } + + private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { + logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") + message match { + case bufferMessage: BufferMessage => { + if (bufferMessage.hasAckId) { + val sentMessageStatus = messageStatuses.synchronized { + messageStatuses.get(bufferMessage.ackId) match { + case Some(status) => { + messageStatuses -= bufferMessage.ackId + status + } + case None => { + throw new Exception("Could not find reference for received ack message " + message.id) + null + } + } + } + sentMessageStatus.synchronized { + sentMessageStatus.ackMessage = Some(message) + sentMessageStatus.attempted = true + sentMessageStatus.acked = true + sentMessageStatus.markDone() + } + } else { + val ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logDebug("Not calling back as callback is null") + None + } + + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logDebug("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } + } + + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) + } + } + case _ => throw new Exception("Unknown type 
message received") + } + } + + private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { + def startNewConnection(): SendingConnection = { + val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + registerRequests.enqueue(newConnection) + + newConnection + } + // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? + val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) + message.senderAddress = id.toSocketAddress() + logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") + connection.send(message) + + wakeupSelector() + } + + private def wakeupSelector() { + selector.wakeup() + } + + def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) + : Future[Option[Message]] = { + val promise = Promise[Option[Message]] + val status = new MessageStatus(message, connectionManagerId, s => promise.success(s.ackMessage)) + messageStatuses.synchronized { + messageStatuses += ((message.id, status)) + } + sendMessage(connectionManagerId, message) + promise.future + } + + def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { + Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) + } + + def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { + onReceiveCallback = callback + } + + def stop() { + selectorThread.interrupt() + selectorThread.join() + selector.close() + val connections = connectionsByKey.values + connections.foreach(_.close()) + if (connectionsByKey.size != 0) { + logWarning("All connections not cleaned up") + } + handleMessageExecutor.shutdown() + handleReadWriteExecutor.shutdown() + handleConnectExecutor.shutdown() + logInfo("ConnectionManager stopped") + } +} + + +private[spark] object ConnectionManager { + + def main(args: Array[String]) { + val manager = new ConnectionManager(9999) + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + }) + + /*testSequentialSending(manager)*/ + /*System.gc()*/ + + /*testParallelSending(manager)*/ + /*System.gc()*/ + + /*testParallelDecreasingSending(manager)*/ + /*System.gc()*/ + + testContinuousSending(manager) + System.gc() + } + + def testSequentialSending(manager: ConnectionManager) { + println("--------------------------") + println("Sequential Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(manager.id, bufferMessage) + }) + println("--------------------------") + println() + } + + def testParallelSending(manager: ConnectionManager) { + println("--------------------------") + println("Parallel Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + (0 until count).map(i => { + val bufferMessage = 
Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(manager.id, bufferMessage) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) + val finishTime = System.currentTimeMillis + + val mb = size * count / 1024.0 / 1024.0 + val ms = finishTime - startTime + val tput = mb * 1000.0 / ms + println("--------------------------") + println("Started at " + startTime + ", finished at " + finishTime) + println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") + println("--------------------------") + println() + } + + def testParallelDecreasingSending(manager: ConnectionManager) { + println("--------------------------") + println("Parallel Decreasing Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) + buffers.foreach(_.flip) + val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 + + val startTime = System.currentTimeMillis + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) + manager.sendMessageReliably(manager.id, bufferMessage) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) + val finishTime = System.currentTimeMillis + + val ms = finishTime - startTime + val tput = mb * 1000.0 / ms + println("--------------------------") + /*println("Started at " + startTime + ", finished at " + finishTime) */ + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") + println("--------------------------") + println() + } + + def testContinuousSending(manager: ConnectionManager) { + println("--------------------------") + println("Continuous Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + while(true) { + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(manager.id, bufferMessage) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) + val finishTime = System.currentTimeMillis + Thread.sleep(1000) + val mb = size * count / 1024.0 / 1024.0 + val ms = finishTime - startTime + val tput = mb * 1000.0 / ms + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") + println("--------------------------") + println() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala new file mode 100644 index 0000000000..50dd9bc2d1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.net.InetSocketAddress + +import org.apache.spark.util.Utils + + +private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + def toSocketAddress() = new InetSocketAddress(host, port) +} + + +private[spark] object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala new file mode 100644 index 0000000000..4f5742d29b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
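A brief round-trip sketch for the ConnectionManagerId case class introduced above; the host and port here are hypothetical, and resolution back through fromSocketAddress normally yields an equal id because it simply reads the host name and port off the InetSocketAddress:

    val id   = ConnectionManagerId("localhost", 9999)        // hypothetical endpoint
    val addr = id.toSocketAddress()                           // java.net.InetSocketAddress("localhost", 9999)
    val same = ConnectionManagerId.fromSocketAddress(addr)    // normally equal to id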
+ */ + +package org.apache.spark.network + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import scala.io.Source + +import java.nio.ByteBuffer +import java.net.InetAddress + +import scala.concurrent.Await +import scala.concurrent.duration._ + +private[spark] object ConnectionManagerTest extends Logging { + def main(args: Array[String]) { + //<mesos cluster> - the master URL + //<slaves file> - a list of slaves to run connectionTest on + //[num of tasks] - the number of parallel tasks to be initiated; default is the number of slave hosts + //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 + //[count] - how many times to run, default is 3 + //[await time in seconds] - the await time (in seconds), default is 600 + if (args.length < 2) { + println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds] ") + System.exit(1) + } + + if (args(0).startsWith("local")) { + println("This runs only on a mesos cluster") + } + + val sc = new SparkContext(args(0), "ConnectionManagerTest") + val slavesFile = Source.fromFile(args(1)) + val slaves = slavesFile.mkString.split("\n") + slavesFile.close() + + /*println("Slaves")*/ + /*slaves.foreach(println)*/ + val tasknum = if (args.length > 2) args(2).toInt else slaves.length + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val count = if (args.length > 4) args(4).toInt else 3 + val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second + println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) + val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( + i => SparkEnv.get.connectionManager.id).collect() + println("\nSlave ConnectionManagerIds") + slaveConnManagerIds.foreach(println) + println + + (0 until count).foreach(i => { + val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { + val connManager = SparkEnv.get.connectionManager + val thisConnManagerId = connManager.id + connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + logInfo("Received [" + msg + "] from [" + id + "]") + None + }) + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") + connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) + }) + val results = futures.map(f => Await.result(f, awaitTime)) + val finishTime = System.currentTimeMillis + Thread.sleep(5000) + + val mb = size * results.size / 1024.0 / 1024.0 + val ms = finishTime - startTime + val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + logInfo(resultStr) + resultStr + }).collect() + + println("---------------------") + println("Run " + i) + resultStrs.foreach(println) + println("---------------------") + }) + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala new file mode 100644 index 0000000000..f2ecc6d439 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer +import java.net.InetSocketAddress + +import scala.collection.mutable.ArrayBuffer + + +private[spark] abstract class Message(val typ: Long, val id: Int) { + var senderAddress: InetSocketAddress = null + var started = false + var startTime = -1L + var finishTime = -1L + + def size: Int + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] + + def timeTaken(): String = (finishTime - startTime).toString + " ms" + + override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" +} + + +private[spark] object Message { + val BUFFER_MESSAGE = 1111111111L + + var lastId = 1 + + def getNewId() = synchronized { + lastId += 1 + if (lastId == 0) { + lastId += 1 + } + lastId + } + + def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { + if (dataBuffers == null) { + return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) + } + if (dataBuffers.exists(_ == null)) { + throw new Exception("Attempting to create buffer message with null buffer") + } + return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) + } + + def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = + createBufferMessage(dataBuffers, 0) + + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { + if (dataBuffer == null) { + return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) + } else { + return createBufferMessage(Array(dataBuffer), ackId) + } + } + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + createBufferMessage(dataBuffer, 0) + + def createBufferMessage(ackId: Int): BufferMessage = { + createBufferMessage(new Array[ByteBuffer](0), ackId) + } + + def create(header: MessageChunkHeader): Message = { + val newMessage: Message = header.typ match { + case BUFFER_MESSAGE => new BufferMessage(header.id, + ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + } + newMessage.senderAddress = header.address + newMessage + } +} diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala new file mode 100644 index 0000000000..e0fe57b80d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
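A small usage sketch for the Message factory methods above; the payload is hypothetical, and BufferMessage is the concrete type these helpers construct:

    import java.nio.ByteBuffer
    val payload = ByteBuffer.wrap("hello".getBytes)
    val request = Message.createBufferMessage(payload)      // ackId defaults to 0
    val ack     = Message.createBufferMessage(request.id)   // empty message acknowledging request.id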
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + + +private[network] +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + + val size = if (buffer == null) 0 else buffer.remaining + + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = { + "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + } +} diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala new file mode 100644 index 0000000000..235fbc39b3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.nio.ByteBuffer + + +private[spark] class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val address: InetSocketAddress) { + lazy val buffer = { + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. + // Refer to network.Connection + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). 
+ flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes" +} + + +private[spark] object MessageChunkHeader { + val HEADER_SIZE = 40 + + def create(buffer: ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala new file mode 100644 index 0000000000..781715108b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer +import java.net.InetAddress + +private[spark] object ReceiverTest { + + def main(args: Array[String]) { + val manager = new ConnectionManager(9999) + println("Started connection manager with id = " + manager.id) + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ + val buffer = ByteBuffer.wrap("response".getBytes()) + Some(Message.createBufferMessage(buffer, msg.id)) + }) + Thread.currentThread.join() + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala new file mode 100644 index 0000000000..777574980f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
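A hedged round-trip sketch for the MessageChunkHeader wire format above, with hypothetical sizes. For an IPv4 address the written fields occupy 8 + 4 + 4 + 4 + 4 + 4 + 4 + 4 = 36 of the 40 header bytes; position(HEADER_SIZE) then pads the buffer out to the fixed length that create() expects:

    import java.net.InetSocketAddress
    val addr   = new InetSocketAddress("127.0.0.1", 9999)                   // hypothetical endpoint
    val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 1, 1024, 512, 0, addr)
    val parsed = MessageChunkHeader.create(header.buffer.duplicate())       // re-parses the 40-byte buffer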
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer +import java.net.InetAddress + +private[spark] object SenderTest { + + def main(args: Array[String]) { + + if (args.length < 2) { + println("Usage: SenderTest <target host> <target port>") + System.exit(1) + } + + val targetHost = args(0) + val targetPort = args(1).toInt + val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) + + val manager = new ConnectionManager(0) + println("Started connection manager with id = " + manager.id) + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + }) + + val size = 100 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val targetServer = args(0) + + val count = 100 + (0 until count).foreach(i => { + val dataMessage = Message.createBufferMessage(buffer.duplicate) + val startTime = System.currentTimeMillis + /*println("Started timer at " + startTime)*/ + val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { + case Some(response) => + val buffer = response.asInstanceOf[BufferMessage].buffers(0) + new String(buffer.array) + case None => "none" + } + val finishTime = System.currentTimeMillis + val mb = size / 1024.0 / 1024.0 + val ms = finishTime - startTime + /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/ + val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr + println(resultStr) + }) + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000..3c29700920 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import io.netty.buffer._ + +import org.apache.spark.Logging + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: String) extends Logging { + + lazy val buffer = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.length) + blockId.foreach((x: Char) => buf.writeByte(x)) + // pad the rest of the header with zeros + if (FileHeader.HEADER_SIZE - buf.readableBytes >= 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("header too long: " + buf.readableBytes + " bytes") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = idBuilder.toString() + new FileHeader(length, blockId) + } + + + def main(args: Array[String]) { + + val header = new FileHeader(25, "block_0") + val buf = header.buffer + val newheader = FileHeader.create(buf) + System.out.println("id=" + newheader.blockId + ", size=" + newheader.fileLen) + + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000..9493ccffd9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.network.netty + +import java.util.concurrent.Executors + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.util.CharsetUtil + +import org.apache.spark.Logging +import org.apache.spark.network.ConnectionManagerId + +import scala.collection.JavaConverters._ + + +private[spark] class ShuffleCopier extends Logging { + + def getBlock(host: String, port: Int, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + + val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) + val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt + val fc = new FileClient(handler, connectTimeout) + + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) + } + + def getBlocks(cmId: ConnectionManagerId, + blocks: Seq[(String, Long)], + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + + for ((blockId, size) <- blocks) { + getBlock(cmId, blockId, resultCollectCallback) + } + } +} + + +private[spark] object ShuffleCopier extends Logging { + + private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + extends FileClientHandler with Logging { + + override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } + + override def handleError(blockId: String) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } + } + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val file = args(2) + val threads = if (args.length > 3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { + def run() { + val copier = new ShuffleCopier() + copier.getBlock(host, port, file, echoResultCollectCallBack) + } + }) + }).asJava + copiers.invokeAll(tasks) + copiers.shutdown + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000..537f225469 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
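A hedged single-block fetch using the ShuffleCopier above; the host, port, block id and timeout value are hypothetical, and a matching shuffle sender must already be serving that block:

    System.setProperty("spark.shuffle.netty.connect.timeout", "30000")   // milliseconds, read by getBlock
    val copier = new ShuffleCopier()
    copier.getBlock("localhost", 12345, "shuffle_0_0_0", ShuffleCopier.echoResultCollectCallBack)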
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.File + +import org.apache.spark.Logging + + +private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { + + val server = new FileServer(pResolver, portIn) + server.start() + + def stop() { + server.stop() + } + + def port: Int = server.getPort() +} + + +/** + * An application for testing the shuffle sender as a standalone program. + */ +private[spark] object ShuffleSender { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") + System.exit(1) + } + + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2).map(new File(_)) + + val pResovler = new PathResolver { + override def getAbsolutePath(blockId: String): String = { + if (!blockId.startsWith("shuffle_")) { + throw new Exception("Block " + blockId + " is not a shuffle block") + } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId) + return file.getAbsolutePath + } + } + val sender = new ShuffleSender(port, pResovler) + } +} diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala new file mode 100644 index 0000000000..f132e2b735 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to + * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, + * and provides most parallel operations. 
+ * + * In addition, [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs + * of key-value pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] + * contains operations available only on RDDs of Doubles; and + * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can + * be saved as SequenceFiles. These operations are automatically available on any RDD of the right + * type (e.g. RDD[(Int, Int)]) through implicit conversions when you + * `import org.apache.spark.SparkContext._`. + */ +package object spark { + // For package docs only +} diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala new file mode 100644 index 0000000000..d71069444a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark._ +import org.apache.spark.scheduler.JobListener +import org.apache.spark.rdd.RDD + +/** + * A JobListener for an approximate single-result action, such as count() or non-parallel reduce(). + * This listener waits up to timeout milliseconds and will return a partial answer even if the + * complete answer is not available by then. + * + * This class assumes that the action is performed on an entire RDD[T] via a function that computes + * a result of type U for each partition, and that the action returns a partial or complete result + * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
+ */ +private[spark] class ApproximateActionListener[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + timeout: Long) + extends JobListener { + + val startTime = System.currentTimeMillis() + val totalTasks = rdd.partitions.size + var finishedTasks = 0 + var failure: Option[Exception] = None // Set if the job has failed (permanently) + var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult + + override def taskSucceeded(index: Int, result: Any) { + synchronized { + evaluator.merge(index, result.asInstanceOf[U]) + finishedTasks += 1 + if (finishedTasks == totalTasks) { + // If we had already returned a PartialResult, set its final value + resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) + // Notify any waiting thread that may have called awaitResult + this.notifyAll() + } + } + } + + override def jobFailed(exception: Exception) { + synchronized { + failure = Some(exception) + this.notifyAll() + } + } + + /** + * Waits for up to timeout milliseconds since the listener was created and then returns a + * PartialResult with the result so far. This may be complete if the whole job is done. + */ + def awaitResult(): PartialResult[R] = synchronized { + val finishTime = startTime + timeout + while (true) { + val time = System.currentTimeMillis() + if (failure != None) { + throw failure.get + } else if (finishedTasks == totalTasks) { + return new PartialResult(evaluator.currentResult(), true) + } else if (time >= finishTime) { + resultObject = Some(new PartialResult(evaluator.currentResult(), false)) + return resultObject.get + } else { + this.wait(finishTime - time) + } + } + // Should never be reached, but required to keep the compiler happy + return null + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala new file mode 100644 index 0000000000..9c2859c8b9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +/** + * An object that computes a function incrementally by merging in results of type U from multiple + * tasks. Allows partial evaluation at any point by calling currentResult(). 
+ */ +private[spark] trait ApproximateEvaluator[U, R] { + def merge(outputId: Int, taskResult: U): Unit + def currentResult(): R +} diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala new file mode 100644 index 0000000000..5f4450859c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +/** + * A Double with error bars on it. + */ +class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { + override def toString(): String = "[%.3f, %.3f]".format(low, high) +} diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala new file mode 100644 index 0000000000..3155dfe165 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import cern.jet.stat.Probability + +/** + * An ApproximateEvaluator for counts. + * + * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might + * be best to make this a special case of GroupedCountEvaluator with one group. 
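For illustration only, a minimal hypothetical implementation of the ApproximateEvaluator trait above that tracks an exact running sum with no error bars might look like the sketch below; the evaluators that follow add the extrapolation and confidence-interval logic:

    private[spark] class ExactSumEvaluator extends ApproximateEvaluator[Long, Long] {
      private var sum = 0L
      override def merge(outputId: Int, taskResult: Long) { sum += taskResult }
      override def currentResult(): Long = sum
    }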
+ */ +private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[Long, BoundedDouble] { + + var outputsMerged = 0 + var sum: Long = 0 + + override def merge(outputId: Int, taskResult: Long) { + outputsMerged += 1 + sum += taskResult + } + + override def currentResult(): BoundedDouble = { + if (outputsMerged == totalOutputs) { + new BoundedDouble(sum, 1.0, sum, sum) + } else if (outputsMerged == 0) { + new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val p = outputsMerged.toDouble / totalOutputs + val mean = (sum + 1 - p) / p + val variance = (sum + 1) * (1 - p) / (p * p) + val stdev = math.sqrt(variance) + val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + new BoundedDouble(mean, confidence, low, high) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala new file mode 100644 index 0000000000..e519e3a548 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import java.util.{HashMap => JHashMap} +import java.util.{Map => JMap} + +import scala.collection.Map +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions.mapAsScalaMap + +import cern.jet.stat.Probability + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +/** + * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. 
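A worked instance of the CountEvaluator extrapolation defined above, with hypothetical numbers: if half of the partitions have been merged (p = 0.5) and they contributed sum = 100, then mean = (100 + 1 - 0.5) / 0.5 = 201, variance = 101 * 0.5 / 0.25 = 202, stdev is about 14.2, and at 95% confidence (confFactor about 1.96) the returned interval is roughly [173, 229].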
+ */ +private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { + + var outputsMerged = 0 + var sums = new OLMap[T] // Sum of counts for each key + + override def merge(outputId: Int, taskResult: OLMap[T]) { + outputsMerged += 1 + val iter = taskResult.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue) + } + } + + override def currentResult(): Map[T, BoundedDouble] = { + if (outputsMerged == totalOutputs) { + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + val sum = entry.getLongValue() + result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + } + result + } else if (outputsMerged == 0) { + new HashMap[T, BoundedDouble] + } else { + val p = outputsMerged.toDouble / totalOutputs + val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + val sum = entry.getLongValue + val mean = (sum + 1 - p) / p + val variance = (sum + 1) * (1 - p) / (p * p) + val stdev = math.sqrt(variance) + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + } + result + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala new file mode 100644 index 0000000000..cf8a5680b6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import java.util.{HashMap => JHashMap} +import java.util.{Map => JMap} + +import scala.collection.mutable.HashMap +import scala.collection.Map +import scala.collection.JavaConversions.mapAsScalaMap + +import org.apache.spark.util.StatCounter + +/** + * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. 
+ */ +private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { + + var outputsMerged = 0 + var sums = new JHashMap[T, StatCounter] // Sum of counts for each key + + override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { + outputsMerged += 1 + val iter = taskResult.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val old = sums.get(entry.getKey) + if (old != null) { + old.merge(entry.getValue) + } else { + sums.put(entry.getKey, entry.getValue) + } + } + } + + override def currentResult(): Map[T, BoundedDouble] = { + if (outputsMerged == totalOutputs) { + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val mean = entry.getValue.mean + result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + } + result + } else if (outputsMerged == 0) { + new HashMap[T, BoundedDouble] + } else { + val p = outputsMerged.toDouble / totalOutputs + val studentTCacher = new StudentTCacher(confidence) + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val counter = entry.getValue + val mean = counter.mean + val stdev = math.sqrt(counter.sampleVariance / counter.count) + val confFactor = studentTCacher.get(counter.count) + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + } + result + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala new file mode 100644 index 0000000000..8225a5d933 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import java.util.{HashMap => JHashMap} +import java.util.{Map => JMap} + +import scala.collection.mutable.HashMap +import scala.collection.Map +import scala.collection.JavaConversions.mapAsScalaMap + +import org.apache.spark.util.StatCounter + +/** + * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. 
+ */ +private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { + + var outputsMerged = 0 + var sums = new JHashMap[T, StatCounter] // Sum of counts for each key + + override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { + outputsMerged += 1 + val iter = taskResult.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val old = sums.get(entry.getKey) + if (old != null) { + old.merge(entry.getValue) + } else { + sums.put(entry.getKey, entry.getValue) + } + } + } + + override def currentResult(): Map[T, BoundedDouble] = { + if (outputsMerged == totalOutputs) { + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val sum = entry.getValue.sum + result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + } + result + } else if (outputsMerged == 0) { + new HashMap[T, BoundedDouble] + } else { + val p = outputsMerged.toDouble / totalOutputs + val studentTCacher = new StudentTCacher(confidence) + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val counter = entry.getValue + val meanEstimate = counter.mean + val meanVar = counter.sampleVariance / counter.count + val countEstimate = (counter.count + 1 - p) / p + val countVar = (counter.count + 1) * (1 - p) / (p * p) + val sumEstimate = meanEstimate * countEstimate + val sumVar = (meanEstimate * meanEstimate * countVar) + + (countEstimate * countEstimate * meanVar) + + (meanVar * countVar) + val sumStdev = math.sqrt(sumVar) + val confFactor = studentTCacher.get(counter.count) + val low = sumEstimate - confFactor * sumStdev + val high = sumEstimate + confFactor * sumStdev + result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + } + result + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala new file mode 100644 index 0000000000..d24959cba8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import cern.jet.stat.Probability + +import org.apache.spark.util.StatCounter + +/** + * An ApproximateEvaluator for means. 
+ */ +private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[StatCounter, BoundedDouble] { + + var outputsMerged = 0 + var counter = new StatCounter + + override def merge(outputId: Int, taskResult: StatCounter) { + outputsMerged += 1 + counter.merge(taskResult) + } + + override def currentResult(): BoundedDouble = { + if (outputsMerged == totalOutputs) { + new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) + } else if (outputsMerged == 0) { + new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val mean = counter.mean + val stdev = math.sqrt(counter.sampleVariance / counter.count) + val confFactor = { + if (counter.count > 100) { + Probability.normalInverse(1 - (1 - confidence) / 2) + } else { + Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + } + } + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + new BoundedDouble(mean, confidence, low, high) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala new file mode 100644 index 0000000000..5ce49b8100 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +class PartialResult[R](initialVal: R, isFinal: Boolean) { + private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None + private var failure: Option[Exception] = None + private var completionHandler: Option[R => Unit] = None + private var failureHandler: Option[Exception => Unit] = None + + def initialValue: R = initialVal + + def isInitialValueFinal: Boolean = isFinal + + /** + * Blocking method to wait for and return the final value. + */ + def getFinalValue(): R = synchronized { + while (finalValue == None && failure == None) { + this.wait() + } + if (finalValue != None) { + return finalValue.get + } else { + throw failure.get + } + } + + /** + * Set a handler to be called when this PartialResult completes. Only one completion handler + * is supported per PartialResult. + */ + def onComplete(handler: R => Unit): PartialResult[R] = synchronized { + if (completionHandler != None) { + throw new UnsupportedOperationException("onComplete cannot be called twice") + } + completionHandler = Some(handler) + if (finalValue != None) { + // We already have a final value, so let's call the handler + handler(finalValue.get) + } + return this + } + + /** + * Set a handler to be called if this PartialResult's job fails. Only one failure handler + * is supported per PartialResult. 
+ */ + def onFail(handler: Exception => Unit) { + synchronized { + if (failureHandler != None) { + throw new UnsupportedOperationException("onFail cannot be called twice") + } + failureHandler = Some(handler) + if (failure != None) { + // We already have a failure, so let's call the handler + handler(failure.get) + } + } + } + + /** + * Transform this PartialResult into a PartialResult of type T. + */ + def map[T](f: R => T) : PartialResult[T] = { + new PartialResult[T](f(initialVal), isFinal) { + override def getFinalValue() : T = synchronized { + f(PartialResult.this.getFinalValue()) + } + override def onComplete(handler: T => Unit): PartialResult[T] = synchronized { + PartialResult.this.onComplete(handler.compose(f)).map(f) + } + override def onFail(handler: Exception => Unit) { + synchronized { + PartialResult.this.onFail(handler) + } + } + override def toString : String = synchronized { + PartialResult.this.getFinalValueInternal() match { + case Some(value) => "(final: " + f(value) + ")" + case None => "(partial: " + initialValue + ")" + } + } + def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) + } + } + + private[spark] def setFinalValue(value: R) { + synchronized { + if (finalValue != None) { + throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult") + } + finalValue = Some(value) + // Call the completion handler if it was set + completionHandler.foreach(h => h(value)) + // Notify any threads that may be calling getFinalValue() + this.notifyAll() + } + } + + private def getFinalValueInternal() = finalValue + + private[spark] def setFailure(exception: Exception) { + synchronized { + if (failure != None) { + throw new UnsupportedOperationException("setFailure called twice on a PartialResult") + } + failure = Some(exception) + // Call the failure handler if it was set + failureHandler.foreach(h => h(exception)) + // Notify any threads that may be calling getFinalValue() + this.notifyAll() + } + } + + override def toString: String = synchronized { + finalValue match { + case Some(value) => "(final: " + value + ")" + case None => "(partial: " + initialValue + ")" + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala new file mode 100644 index 0000000000..92915ee66d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import cern.jet.stat.Probability + +/** + * A utility class for caching Student's T distribution values for a given confidence level + * and various sample sizes. 
This is used by the grouped mean and sum evaluators to efficiently calculate + * confidence intervals for many keys. + */ +private[spark] class StudentTCacher(confidence: Double) { + val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation + val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) + val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) + + def get(sampleSize: Long): Double = { + if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { + normalApprox + } else { + val size = sampleSize.toInt + if (cache(size) < 0) { + cache(size) = Probability.studentTInverse(1 - confidence, size - 1) + } + cache(size) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala new file mode 100644 index 0000000000..a74f800944 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import cern.jet.stat.Probability + +import org.apache.spark.util.StatCounter + +/** + * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them + * together, then uses the formula for the variance of a product of two independent random variables to get + * a variance for the result and compute a confidence interval.
+ */ +private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[StatCounter, BoundedDouble] { + + var outputsMerged = 0 + var counter = new StatCounter + + override def merge(outputId: Int, taskResult: StatCounter) { + outputsMerged += 1 + counter.merge(taskResult) + } + + override def currentResult(): BoundedDouble = { + if (outputsMerged == totalOutputs) { + new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) + } else if (outputsMerged == 0) { + new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val p = outputsMerged.toDouble / totalOutputs + val meanEstimate = counter.mean + val meanVar = counter.sampleVariance / counter.count + val countEstimate = (counter.count + 1 - p) / p + val countVar = (counter.count + 1) * (1 - p) / (p * p) + val sumEstimate = meanEstimate * countEstimate + val sumVar = (meanEstimate * meanEstimate * countVar) + + (countEstimate * countEstimate * meanVar) + + (meanVar * countVar) + val sumStdev = math.sqrt(sumVar) + val confFactor = { + if (counter.count > 100) { + Probability.normalInverse(1 - (1 - confidence) / 2) + } else { + Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + } + } + val low = sumEstimate - confFactor * sumStdev + val high = sumEstimate + confFactor * sumStdev + new BoundedDouble(sumEstimate, confidence, low, high) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala new file mode 100644 index 0000000000..bca6956a18 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} +import org.apache.spark.storage.BlockManager + +private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { + val index = idx +} + +private[spark] +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) + extends RDD[T](sc, Nil) { + + @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + + override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => { + new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] + }).toArray + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager + val blockId = split.asInstanceOf[BlockRDDPartition].blockId + blockManager.get(blockId) match { + case Some(block) => block.asInstanceOf[Iterator[T]] + case None => + throw new Exception("Could not compute split, block " + blockId + " not found") + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + locations_(split.asInstanceOf[BlockRDDPartition].blockId) + } +} + diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala new file mode 100644 index 0000000000..4fb7f3aace --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag +import org.apache.spark._ + + +private[spark] +class CartesianPartition( + idx: Int, + @transient rdd1: RDD[_], + @transient rdd2: RDD[_], + s1Index: Int, + s2Index: Int + ) extends Partition { + var s1 = rdd1.partitions(s1Index) + var s2 = rdd2.partitions(s2Index) + override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + s1 = rdd1.partitions(s1Index) + s2 = rdd2.partitions(s2Index) + oos.defaultWriteObject() + } +} + +private[spark] +class CartesianRDD[T: ClassTag, U:ClassTag]( + sc: SparkContext, + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) + with Serializable { + + val numPartitionsInRdd2 = rdd2.partitions.size + + override def getPartitions: Array[Partition] = { + // create the cross product split + val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { + val idx = s1.index * numPartitionsInRdd2 + s2.index + array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) + } + array + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val currSplit = split.asInstanceOf[CartesianPartition] + (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct + } + + override def compute(split: Partition, context: TaskContext) = { + val currSplit = split.asInstanceOf[CartesianPartition] + for (x <- rdd1.iterator(currSplit.s1, context); + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) + } + + override def getDependencies: Seq[Dependency[_]] = List( + new NarrowDependency(rdd1) { + def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2) + }, + new NarrowDependency(rdd2) { + def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2) + } + ) + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala new file mode 100644 index 0000000000..3f4d4ad46a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag +import org.apache.spark._ +import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.{NullWritable, BytesWritable} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.fs.Path +import java.io.{File, IOException, EOFException} +import java.text.NumberFormat + +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} + +/** + * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + */ +private[spark] +class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) + extends RDD[T](sc, Nil) { + + @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + + override def getPartitions: Array[Partition] = { + val cpath = new Path(checkpointPath) + val numPartitions = + // listStatus can throw exception if path does not exist. + if (fs.exists(cpath)) { + val dirContents = fs.listStatus(cpath) + val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numPart = partitionFiles.size + if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + numPart + } else 0 + + Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) + } + + checkpointData = Some(new RDDCheckpointData[T](this)) + checkpointData.get.cpFile = Some(checkpointPath) + + override def getPreferredLocations(split: Partition): Seq[String] = { + val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) + CheckpointRDD.readFromFile(file, context) + } + + override def checkpoint() { + // Do nothing. CheckpointRDD should not be checkpointed. + } +} + +private[spark] object CheckpointRDD extends Logging { + + def splitIdToFile(splitId: Int): String = { + "part-%05d".format(splitId) + } + + def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { + val env = SparkEnv.get + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) + + val finalOutputName = splitIdToFile(ctx.splitId) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptId) + + if (fs.exists(tempOutputPath)) { + throw new IOException("Checkpoint failed: temporary path " + + tempOutputPath + " already exists") + } + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purposes + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = env.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + serializeStream.writeAll(iterator) + serializeStream.close() + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + logInfo("Deleting tempOutputPath " + tempOutputPath) + fs.delete(tempOutputPath, false) + throw new IOException("Checkpoint failed: failed to save output of task: " + + ctx.attemptId + " and final output path does not exist") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") + fs.delete(tempOutputPath, false) + } + } + } + + def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { + val env = SparkEnv.get + val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val fileInputStream = fs.open(path, bufferSize) + val serializer = env.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback(() => deserializeStream.close()) + + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + + // Test whether CheckpointRDD generates the expected number of partitions despite + // each split file having multiple blocks. This needs to be run on a + // cluster (mesos or standalone) using HDFS. + def main(args: Array[String]) { + import org.apache.spark._ + + val Array(cluster, hdfsPath) = args + val env = SparkEnv.get + val sc = new SparkContext(cluster, "CheckpointRDD Test") + val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) + val path = new Path(hdfsPath, "temp") + val fs = path.getFileSystem(env.hadoop.newConfiguration()) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) + val cpRDD = new CheckpointRDD[Int](sc, path.toString) + assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") + assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") + fs.delete(path, true) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala new file mode 100644 index 0000000000..0187256a8e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.{ObjectOutputStream, IOException} +import java.util.{HashMap => JHashMap} + +import scala.collection.JavaConversions +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} +import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} + + +private[spark] sealed trait CoGroupSplitDep extends Serializable + +private[spark] case class NarrowCoGroupSplitDep( + rdd: RDD[_], + splitIndex: Int, + var split: Partition + ) extends CoGroupSplitDep { + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + split = rdd.partitions(splitIndex) + oos.defaultWriteObject() + } +} + +private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep + +private[spark] +class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) + extends Partition with Serializable { + override val index: Int = idx + override def hashCode(): Int = idx +} + + +/** + * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * tuple with the list of values for that key. + * + * @param rdds parent RDDs. + * @param part partitioner used to partition the shuffle output. + */ +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + + private var serializerClass: String = null + + def setSerializer(cls: String): CoGroupedRDD[K] = { + serializerClass = cls + this + } + + override def getDependencies: Seq[Dependency[_]] = { + rdds.map { rdd: RDD[_ <: Product2[K, _]] => + if (rdd.partitioner == Some(part)) { + logDebug("Adding one-to-one dependency with " + rdd) + new OneToOneDependency(rdd) + } else { + logDebug("Adding shuffle dependency with " + rdd) + new ShuffleDependency[Any, Any](rdd, part, serializerClass) + } + } + } + + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](part.numPartitions) + for (i <- 0 until array.size) { + // Each CoGroupPartition will have a dependency per contributing RDD + array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => + // Assume each RDD contributed a single dependency, and get it + dependencies(j) match { + case s: ShuffleDependency[_, _] => + new ShuffleCoGroupSplitDep(s.shuffleId) + case _ => + new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + } + }.toArray) + } + array + } + + override val partitioner = Some(part) + + override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + val split = s.asInstanceOf[CoGroupPartition] + val numRdds = split.deps.size + // e.g. 
for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) + val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] + + def getSeq(k: K): Seq[ArrayBuffer[Any]] = { + val seq = map.get(k) + if (seq != null) { + seq + } else { + val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) + map.put(k, seq) + seq + } + } + + val ser = SparkEnv.get.serializerManager.get(serializerClass) + for ((dep, depNum) <- split.deps.zipWithIndex) dep match { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { + // Read them from the parent + rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => + getSeq(kv._1)(depNum) += kv._2 + } + } + case ShuffleCoGroupSplitDep(shuffleId) => { + // Read map outputs of shuffle + val fetcher = SparkEnv.get.shuffleFetcher + fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { + kv => getSeq(kv._1)(depNum) += kv._2 + } + } + } + JavaConversions.mapAsScalaMap(map).iterator + } + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala new file mode 100644 index 0000000000..c5de6362a9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark._ +import java.io.{ObjectOutputStream, IOException} +import scala.collection.mutable +import scala.Some +import scala.collection.mutable.ArrayBuffer + +/** + * Class that captures a coalesced RDD by essentially keeping track of parent partitions + * @param index of this coalesced partition + * @param rdd which it belongs to + * @param parentsIndices list of indices in the parent that have been coalesced into this partition + * @param preferredLocation the preferred location for this partition + */ +case class CoalescedRDDPartition( + index: Int, + @transient rdd: RDD[_], + parentsIndices: Array[Int], + @transient preferredLocation: String = "" + ) extends Partition { + var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent partition at the time of task serialization + parents = parentsIndices.map(rdd.partitions(_)) + oos.defaultWriteObject() + } + + /** + * Computes how many of the parents partitions have getPreferredLocation + * as one of their preferredLocations + * @return locality of this coalesced partition between 0 and 1 + */ + def localFraction: Double = { + val loc = parents.count(p => + rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation)) + + if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble) + } +} + +/** + * Represents a coalesced RDD that has fewer partitions than its parent RDD + * This class uses the PartitionCoalescer class to find a good partitioning of the parent RDD + * so that each new partition has roughly the same number of parent partitions and that + * the preferred location of each new partition overlaps with as many preferred locations of its + * parent partitions + * @param prev RDD to be coalesced + * @param maxPartitions number of desired partitions in the coalesced RDD + * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance + */ +class CoalescedRDD[T: ClassManifest]( + @transient var prev: RDD[T], + maxPartitions: Int, + balanceSlack: Double = 0.10) + extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies + + override def getPartitions: Array[Partition] = { + val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) + + pc.run().zipWithIndex.map { + case (pg, i) => + val ids = pg.arr.map(_.index).toArray + new CoalescedRDDPartition(i, prev, ids, pg.prefLoc) + } + } + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + partition.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentPartition => + firstParent[T].iterator(parentPartition, context) + } + } + + override def getDependencies: Seq[Dependency[_]] = { + Seq(new NarrowDependency(prev) { + def getParents(id: Int): Seq[Int] = + partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices + }) + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } + + /** + * Returns the preferred machine for the partition. If split is of type CoalescedRDDPartition, + * then the preferred machine will be one which most parent splits prefer too. 
+ * @param partition + * @return the machine most preferred by split + */ + override def getPreferredLocations(partition: Partition): Seq[String] = { + List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation) + } +} + +/** + * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of + * this RDD computes one or more of the parent ones. It will produce exactly `maxPartitions` if the + * parent had more than maxPartitions, or fewer if the parent had fewer. + * + * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, + * or to avoid having a large number of small tasks when processing a directory with many files. + * + * If there is no locality information (no preferredLocations) in the parent, then the coalescing + * is very simple: chunk parents that are close in the Array in chunks. + * If there is locality information, it proceeds to pack them with the following four goals: + * + * (1) Balance the groups so they roughly have the same number of parent partitions + * (2) Achieve locality per partition, i.e. find one machine which most parent partitions prefer + * (3) Be efficient, i.e. O(n) algorithm for n parent partitions (problem is likely NP-hard) + * (4) Balance preferred machines, i.e. avoid as much as possible picking the same preferred machine + * + * Furthermore, it is assumed that the parent RDD may have many partitions, e.g. 100 000. + * We assume the final number of desired partitions is small, e.g. less than 1000. + * + * The algorithm tries to assign unique preferred machines to each partition. If the number of + * desired partitions is greater than the number of preferred machines (can happen), it needs to + * start picking duplicate preferred machines. This is determined using coupon collector estimation + * (2n log(n)). The load balancing is done using power-of-two randomized bins-balls with one twist: + * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two + * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions + * according to locality. (contact alig for questions) + * + */ + +private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) { + + def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size + def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean = + if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get) + + val rnd = new scala.util.Random(7919) // keep this class deterministic + + // each element of groupArr represents one coalesced partition + val groupArr = ArrayBuffer[PartitionGroup]() + + // hash used to check whether some machine is already in groupArr + val groupHash = mutable.Map[String, ArrayBuffer[PartitionGroup]]() + + // hash used for the first maxPartitions (to avoid duplicates) + val initialHash = mutable.Set[Partition]() + + // determines the tradeoff between load-balancing the partitions sizes and their locality + // e.g. 
balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality + val slack = (balanceSlack * prev.partitions.size).toInt + + var noLocality = true // true if no preferredLocations exist for the parent RDD + + // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones) + def currPrefLocs(part: Partition): Seq[String] = { + prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host) + } + + // this class just keeps iterating and rotating infinitely over the partitions of the RDD + // next() returns the next preferred machine that a partition is replicated on + // the rotator first goes through the first replica copy of each partition, then second, third + // the iterator's return type is a tuple: (replicaString, partition) + class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] { + + var it: Iterator[(String, Partition)] = resetIterator() + + override val isEmpty = !it.hasNext + + // initializes/resets to start iterating from the beginning + def resetIterator() = { + val iterators = (0 to 2).map( x => + prev.partitions.iterator.flatMap(p => { + if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None + } ) + ) + iterators.reduceLeft((x, y) => x ++ y) + } + + // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD + def hasNext(): Boolean = { !isEmpty } + + // return the next preferredLocation of some partition of the RDD + def next(): (String, Partition) = { + if (it.hasNext) + it.next() + else { + it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning + it.next() + } + } + } + + /** + * Sorts and gets the least element of the list associated with key in groupHash + * The returned PartitionGroup is the least loaded of all groups that represent the machine "key" + * @param key string representing a partitioned group on preferred machine key + * @return Option of PartitionGroup that has least elements for key + */ + def getLeastGroupHash(key: String): Option[PartitionGroup] = { + groupHash.get(key).map(_.sortWith(compare).head) + } + + def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = { + if (!initialHash.contains(part)) { + pgroup.arr += part // assign this element to this group now + initialHash += part // needed to avoid assigning partitions to multiple buckets + true + } else { false } + } + + /** + * Initializes targetLen partition groups and assigns a preferredLocation + * This uses coupon collector to estimate how many preferredLocations it must rotate through + * until it has seen most of the preferred locations (2 * n log(n)) + * @param targetLen + */ + def setupGroups(targetLen: Int) { + val rotIt = new LocationIterator(prev) + + // deal with empty case, just create targetLen partition groups with no preferred location + if (!rotIt.hasNext()) { + (1 to targetLen).foreach(x => groupArr += PartitionGroup()) + return + } + + noLocality = false + + // number of iterations needed to be certain that we've seen most preferred locations + val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt + var numCreated = 0 + var tries = 0 + + // rotate through until either targetLen unique/distinct preferred locations have been created + // OR we've rotated expectedCoupons2, in which case we have likely seen all preferred locations, + // i.e.
likely targetLen >> number of preferred locations (more buckets than there are machines) + while (numCreated < targetLen && tries < expectedCoupons2) { + tries += 1 + val (nxt_replica, nxt_part) = rotIt.next() + if (!groupHash.contains(nxt_replica)) { + val pgroup = PartitionGroup(nxt_replica) + groupArr += pgroup + addPartToPGroup(nxt_part, pgroup) + groupHash += (nxt_replica -> (ArrayBuffer(pgroup))) // list in case we have multiple + numCreated += 1 + } + } + + while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates + var (nxt_replica, nxt_part) = rotIt.next() + val pgroup = PartitionGroup(nxt_replica) + groupArr += pgroup + groupHash.get(nxt_replica).get += pgroup + var tries = 0 + while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part + nxt_part = rotIt.next()._2 + tries += 1 + } + numCreated += 1 + } + + } + + /** + * Takes a parent RDD partition and decides which of the partition groups to put it in + * Takes locality into account, but also uses power of 2 choices to load balance + * It strikes a balance between the two using the balanceSlack variable + * @param p partition (ball to be thrown) + * @return partition group (bin to be put in) + */ + def pickBin(p: Partition): PartitionGroup = { + val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs + val prefPart = if (pref == Nil) None else pref.head + + val r1 = rnd.nextInt(groupArr.size) + val r2 = rnd.nextInt(groupArr.size) + val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2) + if (prefPart == None) // if no preferred locations, just use basic power of two + return minPowerOfTwo + + val prefPartActual = prefPart.get + + if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows + return minPowerOfTwo // prefer balance over locality + else { + return prefPartActual // prefer locality over balance + } + } + + def throwBalls() { + if (noLocality) { // no preferredLocations in parent RDD, no randomization needed + if (maxPartitions > groupArr.size) { // just return prev.partitions + for ((p,i) <- prev.partitions.zipWithIndex) { + groupArr(i).arr += p + } + } else { // no locality available, then simply split partitions based on positions in array + for (i <- 0 until maxPartitions) { + val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt + val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt + (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } + } + } + } else { + for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group + pickBin(p).arr += p + } + } + } + + def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray + + /** + * Runs the packing algorithm and returns an array of PartitionGroups that if possible are + * load balanced and grouped by locality + * @return array of partition groups + */ + def run(): Array[PartitionGroup] = { + setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins) + throwBalls() // assign partitions (balls) to each group (bins) + getPartitions + } +} + +private[spark] case class PartitionGroup(prefLoc: String = "") { + var arr = mutable.ArrayBuffer[Partition]() + + def size = arr.size +} diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala new
file mode 100644 index 0000000000..a4bec41752 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.partial.BoundedDouble +import org.apache.spark.partial.MeanEvaluator +import org.apache.spark.partial.PartialResult +import org.apache.spark.partial.SumEvaluator +import org.apache.spark.util.StatCounter +import org.apache.spark.{TaskContext, Logging} + +/** + * Extra functions available on RDDs of Doubles through an implicit conversion. + * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { + /** Add up the elements in this RDD. */ + def sum(): Double = { + self.reduce(_ + _) + } + + /** + * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count + * of the RDD's elements in one operation. + */ + def stats(): StatCounter = { + self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) + } + + /** Compute the mean of this RDD's elements. */ + def mean(): Double = stats().mean + + /** Compute the variance of this RDD's elements. */ + def variance(): Double = stats().variance + + /** Compute the standard deviation of this RDD's elements. */ + def stdev(): Double = stats().stdev + + /** + * Compute the sample standard deviation of this RDD's elements (which corrects for bias in + * estimating the standard deviation by dividing by N-1 instead of N). + */ + def sampleStdev(): Double = stats().sampleStdev + + /** + * Compute the sample variance of this RDD's elements (which corrects for bias in + * estimating the variance by dividing by N-1 instead of N). + */ + def sampleVariance(): Double = stats().sampleVariance + + /** (Experimental) Approximate operation to return the mean within a timeout. */ + def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) + val evaluator = new MeanEvaluator(self.partitions.size, confidence) + self.context.runApproximateJob(self, processPartition, evaluator, timeout) + } + + /** (Experimental) Approximate operation to return the sum within a timeout. 
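+ *
+ * A rough usage sketch (assumes `import org.apache.spark.SparkContext._` is in scope for the
+ * implicit conversion to DoubleRDDFunctions; the RDD and the numbers are illustrative):
+ * {{{
+ *   val data = sc.parallelize(1 to 1000000).map(_.toDouble)
+ *   val approxSum = data.sumApprox(timeout = 5000L, confidence = 0.95)
+ * }}}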
*/ + def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) + val evaluator = new SumEvaluator(self.partitions.size, confidence) + self.context.runApproximateJob(self, processPartition, evaluator, timeout) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala new file mode 100644 index 0000000000..c8900d1a93 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} + + +/** + * An RDD that is empty, i.e. has no element in it. + */ +class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { + + override def getPartitions: Array[Partition] = Array.empty + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + throw new UnsupportedOperationException("empty RDD") + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala new file mode 100644 index 0000000000..5312dc0b59 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.{OneToOneDependency, Partition, TaskContext} + +private[spark] class FilteredRDD[T: ClassManifest]( + prev: RDD[T], + f: T => Boolean) + extends RDD[T](prev) { + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override val partitioner = prev.partitioner // Since filter cannot change a partition's keys + + override def compute(split: Partition, context: TaskContext) = + firstParent[T].iterator(split, context).filter(f) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala new file mode 100644 index 0000000000..cbdf6d84c0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{Partition, TaskContext} + + +private[spark] +class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: T => TraversableOnce[U]) + extends RDD[U](prev) { + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = + firstParent[T].iterator(split, context).flatMap(f) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala new file mode 100644 index 0000000000..82000bac09 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.{TaskContext, Partition} + + +private[spark] +class FlatMappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => TraversableOnce[U]) + extends RDD[(K, U)](prev) { + + override def getPartitions = firstParent[Product2[K, V]].partitions + + override val partitioner = firstParent[Product2[K, V]].partitioner + + override def compute(split: Partition, context: TaskContext) = { + firstParent[Product2[K, V]].iterator(split, context).flatMap { case Product2(k, v) => + f(v).map(x => (k, x)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala new file mode 100644 index 0000000000..829545d7b0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{Partition, TaskContext} + +private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) + extends RDD[Array[T]](prev) { + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = + Array(firstParent[T].iterator(split, context).toArray).iterator +} diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala new file mode 100644 index 0000000000..2cb6734e41 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.io.EOFException + +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.mapred.InputSplit +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.RecordReader +import org.apache.hadoop.mapred.Reporter +import org.apache.hadoop.util.ReflectionUtils + +import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.util.NextIterator +import org.apache.hadoop.conf.{Configuration, Configurable} + + +/** + * A Spark split class that wraps around a Hadoop InputSplit. + */ +private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) + extends Partition { + + val inputSplit = new SerializableWritable[InputSplit](s) + + override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt + + override val index: Int = idx +} + +/** + * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. files in HDFS, the local file + * system, or S3, tables in HBase, etc). + */ +class HadoopRDD[K, V]( + sc: SparkContext, + @transient conf: JobConf, + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V], + minSplits: Int) + extends RDD[(K, V)](sc, Nil) with Logging { + + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it + private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + + override def getPartitions: Array[Partition] = { + val env = SparkEnv.get + env.hadoop.addCredentials(conf) + val inputFormat = createInputFormat(conf) + if (inputFormat.isInstanceOf[Configurable]) { + inputFormat.asInstanceOf[Configurable].setConf(conf) + } + val inputSplits = inputFormat.getSplits(conf, minSplits) + val array = new Array[Partition](inputSplits.size) + for (i <- 0 until inputSplits.size) { + array(i) = new HadoopPartition(id, i, inputSplits(i)) + } + array + } + + def createInputFormat(conf: JobConf): InputFormat[K, V] = { + ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) + .asInstanceOf[InputFormat[K, V]] + } + + override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { + val split = theSplit.asInstanceOf[HadoopPartition] + logInfo("Input split: " + split.inputSplit) + var reader: RecordReader[K, V] = null + + val conf = confBroadcast.value.value + val fmt = createInputFormat(conf) + if (fmt.isInstanceOf[Configurable]) { + fmt.asInstanceOf[Configurable].setConf(conf) + } + reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback{ () => closeIfNeeded() } + + val key: K = reader.createKey() + val value: V = reader.createValue() + + override def getNext() = { + try { + finished = !reader.next(key, value) + } catch { + case eof: EOFException => + finished = true + } + (key, value) + } + + override def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + // TODO: Filtering out "localhost" in case of file:// URLs + val hadoopSplit = split.asInstanceOf[HadoopPartition] + hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") + } + + override def checkpoint() { + // Do nothing. Hadoop RDD should not be checkpointed. 
+ } + + def getConf: Configuration = confBroadcast.value.value +} diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala new file mode 100644 index 0000000000..e72f86fb13 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.sql.{Connection, ResultSet} + +import scala.reflect.ClassTag +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.util.NextIterator + +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + +/** + * An RDD that executes an SQL query on a JDBC connection and reads results. + * For usage example, see test case JdbcRDDSuite. + * + * @param getConnection a function that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. 
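+ *
+ * A rough usage sketch (the driver, connection string, table, and bounds are illustrative):
+ * {{{
+ *   val rdd = new JdbcRDD(sc,
+ *     () => java.sql.DriverManager.getConnection("jdbc:mysql://localhost/test"),
+ *     "SELECT title, author FROM books WHERE ? <= id AND id <= ?",
+ *     lowerBound = 1L, upperBound = 1000L, numPartitions = 10,
+ *     mapRow = (r: java.sql.ResultSet) => (r.getString(1), r.getString(2)))
+ * }}}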
*/ +class JdbcRDD[T: ClassTag]( + sc: SparkContext, + getConnection: () => Connection, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = { + // bounds are inclusive, hence the + 1 here and - 1 on end + val length = 1 + upperBound - lowerBound + (0 until numPartitions).map(i => { + val start = lowerBound + ((i * length) / numPartitions).toLong + val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 + new JdbcPartition(i, start, end) + }).toArray + } + + override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + context.addOnCompleteCallback{ () => closeIfNeeded() } + val part = thePart.asInstanceOf[JdbcPartition] + val conn = getConnection() + val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + + // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, + // rather than pulling entire resultset into memory. + // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html + if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { + stmt.setFetchSize(Integer.MIN_VALUE) + logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") + } + + stmt.setLong(1, part.lower) + stmt.setLong(2, part.upper) + val rs = stmt.executeQuery() + + override def getNext: T = { + if (rs.next()) { + mapRow(rs) + } else { + finished = true + null.asInstanceOf[T] + } + } + + override def close() { + try { + if (null != rs && ! rs.isClosed()) rs.close() + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt && ! stmt.isClosed()) stmt.close() + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } + try { + if (null != conn && ! conn.isClosed()) conn.close() + logInfo("closed connection") + } catch { + case e: Exception => logWarning("Exception closing connection", e) + } + } + } +} + +object JdbcRDD { + def resultSetToObjectArray(rs: ResultSet) = { + Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala new file mode 100644 index 0000000000..203179c4ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.{Partition, TaskContext} + + +private[spark] +class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = + if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = + f(firstParent[T].iterator(split, context)) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala new file mode 100644 index 0000000000..3ed8339010 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{Partition, TaskContext} + + +/** + * A variant of the MapPartitionsRDD that passes the partition index into the + * closure. This can be used to generate or collect partition specific + * information such as the number of tuples in a partition. + */ +private[spark] +class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean + ) extends RDD[U](prev) { + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override val partitioner = if (preservesPartitioning) prev.partitioner else None + + override def compute(split: Partition, context: TaskContext) = + f(split.index, firstParent[T].iterator(split, context)) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala new file mode 100644 index 0000000000..e8be1c4816 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{Partition, TaskContext} + +private[spark] +class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) + extends RDD[U](prev) { + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = + firstParent[T].iterator(split, context).map(f) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala new file mode 100644 index 0000000000..d33c1af581 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + + +import org.apache.spark.{TaskContext, Partition} + +private[spark] +class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U) + extends RDD[(K, U)](prev) { + + override def getPartitions = firstParent[Product2[K, U]].partitions + + override val partitioner = firstParent[Product2[K, U]].partitioner + + override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = { + firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k ,v) => (k, f(v)) } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala new file mode 100644 index 0000000000..7b3a89f7e0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext} + + +private[spark] +class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) + extends Partition { + + val serializableHadoopSplit = new SerializableWritable(rawSplit) + + override def hashCode(): Int = (41 * (41 + rddId) + index) +} + +class NewHadoopRDD[K, V]( + sc : SparkContext, + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V], + @transient conf: Configuration) + extends RDD[(K, V)](sc, Nil) + with SparkHadoopMapReduceUtil + with Logging { + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + // private val serializableConf = new SerializableWritable(conf) + + private val jobtrackerId: String = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + formatter.format(new Date()) + } + + @transient private val jobId = new JobID(jobtrackerId, id) + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + if (inputFormat.isInstanceOf[Configurable]) { + inputFormat.asInstanceOf[Configurable].setConf(conf) + } + val jobContext = newJobContext(conf, jobId) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } + + override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[NewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = confBroadcast.value.value + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + if (format.isInstanceOf[Configurable]) { + format.asInstanceOf[Configurable].setConf(conf) + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. 
+ context.addOnCompleteCallback(() => close()) + + var havePair = false + var finished = false + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + havePair = !finished + } + !finished + } + + override def next: (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + return (reader.getCurrentKey, reader.getCurrentValue) + } + + private def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val theSplit = split.asInstanceOf[NewHadoopPartition] + theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") + } + + def getConf: Configuration = confBroadcast.value.value +} + diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala new file mode 100644 index 0000000000..697be8b997 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{RangePartitioner, Logging} + +/** + * Extra functions available on RDDs of (key, value) pairs where the key is sortable through + * an implicit conversion. Import `org.apache.spark.SparkContext._` at the top of your program to + * use these functions. They will work with any key type that has a `scala.math.Ordered` + * implementation. + */ +class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, + V: ClassManifest, + P <: Product2[K, V] : ClassManifest]( + self: RDD[P]) + extends Logging with Serializable { + + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). 
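As a quick illustration of the sort described here (editorial sketch, not part of this commit; assumes an existing SparkContext `sc` and the `SparkContext._` implicits in scope):

import org.apache.spark.SparkContext._   // brings the ordered pair-RDD functions into scope

val scores = sc.parallelize(Seq(("carol", 9), ("alice", 7), ("bob", 3)))
val byName = scores.sortByKey()                                    // ascending by key
val byScoreDesc = scores.map(_.swap).sortByKey(ascending = false)  // highest score first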
+ */ + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { + val part = new RangePartitioner(numPartitions, self, ascending) + val shuffled = new ShuffledRDD[K, V, P](self, part) + shuffled.mapPartitions(iter => { + val buf = iter.toArray + if (ascending) { + buf.sortWith((x, y) => x._1 < y._1).iterator + } else { + buf.sortWith((x, y) => x._1 > y._1).iterator + } + }, preservesPartitioning = true) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala new file mode 100644 index 0000000000..aed585e6a1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.nio.ByteBuffer +import java.util.Date +import java.text.SimpleDateFormat +import java.util.{HashMap => JHashMap} + +import scala.collection.{mutable, Map} +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ +import scala.reflect.{ ClassTag, classTag} + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.mapred.FileOutputFormat +import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} +import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter} + +import org.apache.spark._ +import org.apache.spark.SparkContext._ +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.Aggregator +import org.apache.spark.Partitioner +import org.apache.spark.Partitioner.defaultPartitioner + +/** + * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. + * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) + extends Logging + with SparkHadoopMapReduceUtil + with Serializable { + + /** + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C + * Note that V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). 
Users provide three functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD, and whether to perform + * map-side aggregation (if a mapper can produce multiple items with the same key). + */ + def combineByKey[C](createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializerClass: String = null): RDD[(K, C)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } + val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + if (self.partitioner == Some(partitioner)) { + self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + } else if (mapSideCombine) { + val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) + .setSerializer(serializerClass) + partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) + } else { + // Don't apply map-side combiner. + // A sanity check to make sure mergeCombiners is not defined. + assert(mergeCombiners == null) + val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) + values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + } + } + + /** + * Simplified version of combineByKey that hash-partitions the output RDD. + */ + def combineByKey[C](createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int): RDD[(K, C)] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + // When deserializing, use a lazy val to create just one instance of the serializer per task + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + + combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). 
+ */ + def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { + foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { + foldByKey(zeroValue, defaultPartitioner(self))(func) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ + def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { + combineByKey[V]((v: V) => v, func, func, partitioner) + } + + /** + * Merge the values for each key using an associative reduce function, but return the results + * immediately to the master as a Map. This will also perform the merging locally on each mapper + * before sending results to a reducer, similarly to a "combiner" in MapReduce. + */ + def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + + if (getKeyClass().isArray) { + throw new SparkException("reduceByKeyLocally() does not support array keys") + } + + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { + val map = new JHashMap[K, V] + iter.foreach { case (k, v) => + val old = map.get(k) + map.put(k, if (old == null) v else func(old, v)) + } + Iterator(map) + } + + def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { + m2.foreach { case (k, v) => + val old = m1.get(k) + m1.put(k, if (old == null) v else func(old, v)) + } + m1 + } + + self.mapPartitions(reducePartition).reduce(mergeMaps) + } + + /** Alias for reduceByKeyLocally */ + def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) + + /** Count the number of elements for each key, and return the result to the master as a Map. */ + def countByKey(): Map[K, Long] = self.map(_._1).countByValue() + + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. + */ + def countByKeyApprox(timeout: Long, confidence: Double = 0.95) + : PartialResult[Map[K, BoundedDouble]] = { + self.map(_._1).countByValueApprox(timeout, confidence) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) + } + + /** + * Group the values for each key in the RDD into a single sequence. Allows controlling the + * partitioning of the resulting key-value pair RDD by passing a Partitioner. + */ + def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + // groupByKey shouldn't use map side combine because map side combine does not + // reduce the amount of data shuffled and requires all map side data be inserted + // into a hash table, leading to more objects in the old gen. 
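The practical upshot of the comment above, sketched for a simple word count (editorial illustration, not part of this commit; assumes an existing SparkContext `sc`): reduceByKey combines values map-side before the shuffle, while groupByKey ships every individual value across the network.

val words = sc.parallelize(Seq("a", "b", "a", "c", "a")).map(w => (w, 1))

// Partial sums are computed on the map side, then merged after the shuffle.
val counts = words.reduceByKey(_ + _)

// Same answer, but every individual 1 is shuffled, since map-side combine is disabled here.
val countsViaGroup = words.groupByKey().mapValues(_.sum)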
+ def createCombiner(v: V) = ArrayBuffer(v) + def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + val bufs = combineByKey[ArrayBuffer[V]]( + createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) + bufs.asInstanceOf[RDD[(K, Seq[V])]] + } + + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with into `numPartitions` partitions. + */ + def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { + groupByKey(new HashPartitioner(numPartitions)) + } + + /** + * Return a copy of the RDD partitioned using the specified partitioner. + */ + def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { + if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + new ShuffledRDD[K, V, (K, V)](self, partitioner) + } + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. + */ + def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { + this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => + for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to + * partition the output RDD. + */ + def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { + this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => + if (ws.isEmpty) { + vs.iterator.map(v => (v, None)) + } else { + for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w)) + } + } + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to + * partition the output RDD. + */ + def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) + : RDD[(K, (Option[V], W))] = { + this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => + if (vs.isEmpty) { + ws.iterator.map(w => (None, w)) + } else { + for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w) + } + } + } + + /** + * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. + */ + def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) + : RDD[(K, C)] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. + */ + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + reduceByKey(defaultPartitioner(self), func) + } + + /** + * Group the values for each key in the RDD into a single sequence. 
Hash-partitions the + * resulting RDD with the existing partitioner/parallelism level. + */ + def groupByKey(): RDD[(K, Seq[V])] = { + groupByKey(defaultPartitioner(self)) + } + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ + def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { + join(other, defaultPartitioner(self, other)) + } + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ + def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { + join(other, new HashPartitioner(numPartitions)) + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * using the existing partitioner/parallelism level. + */ + def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { + leftOuterJoin(other, defaultPartitioner(self, other)) + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * into `numPartitions` partitions. + */ + def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { + leftOuterJoin(other, new HashPartitioner(numPartitions)) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD using the existing partitioner/parallelism level. + */ + def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { + rightOuterJoin(other, defaultPartitioner(self, other)) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD into the given number of partitions. + */ + def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { + rightOuterJoin(other, new HashPartitioner(numPartitions)) + } + + /** + * Return the key-value pairs in this RDD to the master as a Map. + */ + def collectAsMap(): Map[K, V] = { + val data = self.toArray() + val map = new mutable.HashMap[K, V] + map.sizeHint(data.length) + data.foreach { case (k, v) => map.put(k, v) } + map + } + + /** + * Pass each value in the key-value pair RDD through a map function without changing the keys; + * this also retains the original RDD's partitioning. 
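A short sketch of the join and value-transformation operations documented above (editorial illustration, not part of this commit; assumes an existing SparkContext `sc` and the pair-RDD implicits from `SparkContext._`):

val ages   = sc.parallelize(Seq(("alice", 30), ("bob", 25)))
val cities = sc.parallelize(Seq(("alice", "SF"), ("carol", "NYC")))

val joined = ages.join(cities)            // ("alice", (30, "SF"))
val left   = ages.leftOuterJoin(cities)   // keeps ("bob", (25, None))
val older  = ages.mapValues(_ + 1)        // keys and partitioner are unchanged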
+ */ + def mapValues[U](f: V => U): RDD[(K, U)] = { + val cleanF = self.context.clean(f) + new MappedValuesRDD(self, cleanF) + } + + /** + * Pass each value in the key-value pair RDD through a flatMap function without changing the + * keys; this also retains the original RDD's partitioning. + */ + def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { + val cleanF = self.context.clean(f) + new FlatMappedValuesRDD(self, cleanF) + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag) + prfs.mapValues { case Seq(vs, ws) => + (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + } + } + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag) + prfs.mapValues { case Seq(vs, w1s, w2s) => + (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) + } + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner(self, other)) + } + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + cogroup(other1, other2, defaultPartitioner(self, other1, other2)) + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, new HashPartitioner(numPartitions)) + } + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + cogroup(other1, other2, new HashPartitioner(numPartitions)) + } + + /** Alias for cogroup. */ + def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner(self, other)) + } + + /** Alias for cogroup. 
*/ + def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + cogroup(other1, other2, defaultPartitioner(self, other1, other2)) + } + + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + + /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = + subtractByKey(other, new HashPartitioner(numPartitions)) + + /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = + new SubtractedRDD[K, V, W](self, other, p) + + /** + * Return the list of values in the RDD for key `key`. This operation is done efficiently if the + * RDD has a known partitioner by only searching the partition that the key maps to. + */ + def lookup(key: K): Seq[V] = { + self.partitioner match { + case Some(p) => + val index = p.getPartition(key) + def process(it: Iterator[(K, V)]): Seq[V] = { + val buf = new ArrayBuffer[V] + for ((k, v) <- it if k == key) { + buf += v + } + buf + } + val res = self.context.runJob(self, process _, Array(index), false) + res(0) + case None => + self.filter(_._1 == key).map(_._2).collect() + } + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. + */ + def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]]) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress the result with the + * supplied codec. + */ + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. + */ + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]]) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. 
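A brief sketch of subtractByKey and lookup as documented above (editorial illustration, not part of this commit; assumes an existing SparkContext `sc`):

val inventory = sc.parallelize(Seq(("apple", 3), ("pear", 7), ("plum", 1)))
val sold      = sc.parallelize(Seq(("pear", 2)))

val unsold = inventory.subtractByKey(sold)  // keeps ("apple", 3) and ("plum", 1)
val pears  = inventory.lookup("pear")       // Seq(7); scans a single partition if the
                                            // RDD has a known partitioner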
+ */ + def saveAsNewAPIHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = self.context.hadoopConfiguration) { + val job = new NewAPIHadoopJob(conf) + job.setOutputKeyClass(keyClass) + job.setOutputValueClass(valueClass) + val wrappedConf = new SerializableWritable(job.getConfiguration) + NewFileOutputFormat.setOutputPath(job, new Path(path)) + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + val stageId = self.id + def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + /* "reduce task" <split #> <attempt # = spark task #> */ + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = outputFormatClass.newInstance + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + while (iter.hasNext) { + val (k, v) = iter.next() + writer.write(k, v) + } + writer.close(hadoopContext) + committer.commitTask(hadoopContext) + return 1 + } + val jobFormat = outputFormatClass.newInstance + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. + */ + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + jobCommitter.setupJob(jobTaskContext) + val count = self.context.runJob(self, writeShard _).sum + jobCommitter.commitJob(jobTaskContext) + jobCommitter.cleanupJob(jobTaskContext) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress with the supplied codec. + */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + codec: Class[_ <: CompressionCodec]) { + saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, + new JobConf(self.context.hadoopConfiguration), Some(codec)) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. 
+ */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf = new JobConf(self.context.hadoopConfiguration), + codec: Option[Class[_ <: CompressionCodec]] = None) { + conf.setOutputKeyClass(keyClass) + conf.setOutputValueClass(valueClass) + // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug + conf.set("mapred.output.format.class", outputFormatClass.getName) + for (c <- codec) { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + } + conf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) + saveAsHadoopDataset(conf) + } + + /** + * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for + * that storage system. The JobConf should set an OutputFormat and any output paths required + * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop + * MapReduce job. + */ + def saveAsHadoopDataset(conf: JobConf) { + val outputFormatClass = conf.getOutputFormat + val keyClass = conf.getOutputKeyClass + val valueClass = conf.getOutputValueClass + if (outputFormatClass == null) { + throw new SparkException("Output format class not set") + } + if (keyClass == null) { + throw new SparkException("Output key class not set") + } + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + + logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") + + val writer = new SparkHadoopWriter(conf) + writer.preSetup() + + def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + + writer.setup(context.stageId, context.splitId, attemptNumber) + writer.open() + + var count = 0 + while(iter.hasNext) { + val record = iter.next() + count += 1 + writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) + } + + writer.close() + writer.commit() + } + + self.context.runJob(self, writeToFile _) + writer.commitJob() + writer.cleanup() + } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys: RDD[K] = self.map(_._1) + + /** + * Return an RDD with the values of each tuple. + */ + def values: RDD[V] = self.map(_._2) + + private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass + + private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass +} + +private[spark] object ClassTags { + val seqSeqClassTag = classTag[Seq[Seq[_]]] +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala new file mode 100644 index 0000000000..78fe0cdcdb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.immutable.NumericRange +import scala.collection.mutable.ArrayBuffer +import scala.collection.Map +import scala.reflect.ClassTag + +import org.apache.spark._ +import java.io._ +import scala.Serializable +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.Utils + +private[spark] class ParallelCollectionPartition[T: ClassTag]( + var rddId: Long, + var slice: Int, + var values: Seq[T]) + extends Partition with Serializable { + + def iterator: Iterator[T] = values.iterator + + override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt + + override def equals(other: Any): Boolean = other match { + case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice) + case _ => false + } + + override def index: Int = slice + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + + // Treat java serializer with default action rather than going thru serialization, to avoid a + // separate serialization header. + + sfactory match { + case js: JavaSerializer => out.defaultWriteObject() + case _ => + out.writeLong(rddId) + out.writeInt(slice) + + val ser = sfactory.newInstance() + Utils.serializeViaNestedStream(out, ser)(_.writeObject(values)) + } + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + sfactory match { + case js: JavaSerializer => in.defaultReadObject() + case _ => + rddId = in.readLong() + slice = in.readInt() + + val ser = sfactory.newInstance() + Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) + } + } +} + +private[spark] class ParallelCollectionRDD[T: ClassTag]( + @transient sc: SparkContext, + @transient data: Seq[T], + numSlices: Int, + locationPrefs: Map[Int, Seq[String]]) + extends RDD[T](sc, Nil) { + // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets + // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split + // instead. + // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. + + override def getPartitions: Array[Partition] = { + val slices = ParallelCollectionRDD.slice(data, numSlices).toArray + slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray + } + + override def compute(s: Partition, context: TaskContext) = + s.asInstanceOf[ParallelCollectionPartition[T]].iterator + + override def getPreferredLocations(s: Partition): Seq[String] = { + locationPrefs.getOrElse(s.index, Nil) + } +} + +private object ParallelCollectionRDD { + /** + * Slice a collection into numSlices sub-collections. 
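To make the slicing behaviour concrete (editorial sketch, not part of this commit; assumes an existing SparkContext `sc`): parallelizing a Range keeps each slice as a sub-Range, so the driver never materializes the individual elements.

val nums = sc.parallelize(1 to 1000000, 4)
// Each of the 4 partitions holds a Range of 250000 numbers, e.g. 1 to 250000.
nums.glom().map(_.length).collect()   // Array(250000, 250000, 250000, 250000)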
One extra thing we do here is to treat Range + * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes + * it efficient to run Spark over RDDs representing large sets of numbers. + */ + def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { + if (numSlices < 1) { + throw new IllegalArgumentException("Positive number of slices required") + } + seq match { + case r: Range.Inclusive => { + val sign = if (r.step < 0) { + -1 + } else { + 1 + } + slice(new Range( + r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) + } + case r: Range => { + (0 until numSlices).map(i => { + val start = ((i * r.length.toLong) / numSlices).toInt + val end = (((i + 1) * r.length.toLong) / numSlices).toInt + new Range(r.start + start * r.step, r.start + end * r.step, r.step) + }).asInstanceOf[Seq[Seq[T]]] + } + case nr: NumericRange[_] => { + // For ranges of Long, Double, BigInteger, etc + val slices = new ArrayBuffer[Seq[T]](numSlices) + val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything + var r = nr + for (i <- 0 until numSlices) { + slices += r.take(sliceSize).asInstanceOf[Seq[T]] + r = r.drop(sliceSize) + } + slices + } + case _ => { + val array = seq.toArray // To prevent O(n^2) operations for List etc + (0 until numSlices).map(i => { + val start = ((i * array.length.toLong) / numSlices).toInt + val end = (((i + 1) * array.length.toLong) / numSlices).toInt + array.slice(start, end).toSeq + }) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala new file mode 100644 index 0000000000..bb9b309a70 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext} + + +class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition { + override val index = idx +} + + +/** + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. 
+ */ +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val partitions: Array[Partition] = rdd.partitions.zipWithIndex + .filter(s => partitionFilterFunc(s._2)) + .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } + + override def getParents(partitionId: Int) = List(partitions(partitionId).index) +} + + +/** + * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * all partitions. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on partitions that don't have the range covering the key. + */ +class PartitionPruningRDD[T: ClassTag]( + @transient prev: RDD[T], + @transient partitionFilterFunc: Int => Boolean) + extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { + + override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + + override protected def getPartitions: Array[Partition] = + getDependencies.head.asInstanceOf[PruneDependency[T]].partitions +} + + +object PartitionPruningRDD { + + /** + * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD + * when its type T is not known at compile time. + */ + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala new file mode 100644 index 0000000000..1dbbe39898 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.PrintWriter +import java.util.StringTokenizer + +import scala.collection.Map +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.io.Source +import scala.reflect.ClassTag + +import org.apache.spark.{SparkEnv, Partition, TaskContext} +import org.apache.spark.broadcast.Broadcast + + +/** + * An RDD that pipes the contents of each parent partition through an external command + * (printing them one per line) and returns the output as a collection of strings. 
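Users typically reach this class through RDD's `pipe` operators; a minimal sketch follows (editorial illustration, not part of this commit; assumes an existing SparkContext `sc`, and `./score.sh` is a hypothetical script):

val names = sc.parallelize(Seq("spark", "hadoop", "mapreduce"))

// Each element is written to the subprocess's stdin, one per line; each line of
// stdout becomes an element of the resulting RDD.
val upper  = names.pipe("tr a-z A-Z")                        // SPARK, HADOOP, ...
val scored = names.pipe("./score.sh", Map("MODE" -> "fast")) // with extra env variables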
+ */ +class PipedRDD[T: ClassTag]( + prev: RDD[T], + command: Seq[String], + envVars: Map[String, String], + printPipeContext: (String => Unit) => Unit, + printRDDElement: (T, String => Unit) => Unit) + extends RDD[String](prev) { + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this( + prev: RDD[T], + command: String, + envVars: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null) = + this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) + + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext): Iterator[String] = { + val pb = new ProcessBuilder(command) + // Add the environmental variables to the process. + val currentEnvVars = pb.environment() + envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } + + val proc = pb.start() + val env = SparkEnv.get + + // Start a thread to print the process's stderr to ours + new Thread("stderr reader for " + command) { + override def run() { + for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + + // Start a thread to feed the process input from our parent's iterator + new Thread("stdin writer for " + command) { + override def run() { + SparkEnv.set(env) + val out = new PrintWriter(proc.getOutputStream) + + // input the pipe context firstly + if (printPipeContext != null) { + printPipeContext(out.println(_)) + } + for (elem <- firstParent[T].iterator(split, context)) { + if (printRDDElement != null) { + printRDDElement(elem, out.println(_)) + } else { + out.println(elem) + } + } + out.close() + } + }.start() + + // Return an iterator that read lines from the process's stdout + val lines = Source.fromInputStream(proc.getInputStream).getLines + return new Iterator[String] { + def next() = lines.next() + def hasNext = { + if (lines.hasNext) { + true + } else { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + false + } + } + } + } +} + +object PipedRDD { + // Split a string into words using a standard StringTokenizer + def tokenize(command: String): Seq[String] = { + val buf = new ArrayBuffer[String] + val tok = new StringTokenizer(command) + while(tok.hasMoreElements) { + buf += tok.nextToken() + } + buf + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala new file mode 100644 index 0000000000..70c967f4bf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -0,0 +1,945 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.Random + +import scala.collection.Map +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.ArrayBuffer + +import scala.collection.mutable.HashMap +import scala.reflect.{classTag, ClassTag} + +import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapred.TextOutputFormat + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +import org.apache.spark.Partitioner._ +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.partial.BoundedDouble +import org.apache.spark.partial.CountEvaluator +import org.apache.spark.partial.GroupedCountEvaluator +import org.apache.spark.partial.PartialResult +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{Utils, BoundedPriorityQueue} + +import org.apache.spark.SparkContext._ +import org.apache.spark._ + +/** + * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, + * partitioned collection of elements that can be operated on in parallel. This class contains the + * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, + * [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs of key-value + * pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] contains + * operations available only on RDDs of Doubles; and [[org.apache.spark.rdd.SequenceFileRDDFunctions]] + * contains operations available on RDDs that can be saved as SequenceFiles. These operations are + * automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit + * conversions when you `import org.apache.spark.SparkContext._`. + * + * Internally, each RDD is characterized by five main properties: + * + * - A list of partitions + * - A function for computing each split + * - A list of dependencies on other RDDs + * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) + * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for + * an HDFS file) + * + * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD + * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for + * reading data from a new storage system) by overriding these functions. Please refer to the + * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details + * on RDD internals. + */ +abstract class RDD[T: ClassTag]( + @transient private var sc: SparkContext, + @transient private var deps: Seq[Dependency[_]] + ) extends Serializable with Logging { + + /** Construct an RDD with just a one-to-one dependency on one parent */ + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) + + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= + + /** Implemented by subclasses to compute a given partition. 
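Since the Scaladoc above notes that custom RDDs can be written by overriding these methods, here is a minimal, self-contained sketch of one (editorial illustration, not part of this commit): an RDD that emits a constant value from every partition.

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

private class ConstantPartition(val index: Int) extends Partition

class ConstantRDD(sc: SparkContext, value: Int, numSlices: Int)
  extends RDD[Int](sc, Nil) {                        // Nil: no parent dependencies

  override def getPartitions: Array[Partition] =
    Array.tabulate(numSlices)(i => new ConstantPartition(i))

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator.fill(1000)(value)                       // each partition yields 1000 copies
}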
*/ + def compute(split: Partition, context: TaskContext): Iterator[T] + + /** + * Implemented by subclasses to return the set of partitions in this RDD. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getPartitions: Array[Partition] + + /** + * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getDependencies: Seq[Dependency[_]] = deps + + /** Optionally overridden by subclasses to specify placement preferences. */ + protected def getPreferredLocations(split: Partition): Seq[String] = Nil + + /** Optionally overridden by subclasses to specify how they are partitioned. */ + val partitioner: Option[Partitioner] = None + + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= + + /** The SparkContext that created this RDD. */ + def sparkContext: SparkContext = sc + + /** A unique ID for this RDD (within its SparkContext). */ + val id: Int = sc.newRddId() + + /** A friendly name for this RDD */ + var name: String = null + + /** Assign a name to this RDD */ + def setName(_name: String) = { + name = _name + this + } + + /** User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo.firstUserClass + + /** Reset generator*/ + def setGenerator(_generator: String) = { + generator = _generator + } + + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. + */ + def persist(newLevel: StorageLevel): RDD[T] = { + // TODO: Handle changes of StorageLevel + if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { + throw new UnsupportedOperationException( + "Cannot change storage level of an RDD after it was already assigned a level") + } + storageLevel = newLevel + // Register the RDD with the SparkContext + sc.persistentRdds(id) = this + this + } + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def cache(): RDD[T] = persist() + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + * @return This RDD. + */ + def unpersist(blocking: Boolean = true): RDD[T] = { + logInfo("Removing RDD " + id + " from persistence list") + sc.env.blockManager.master.removeRdd(id, blocking) + sc.persistentRdds.remove(id) + storageLevel = StorageLevel.NONE + this + } + + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. 
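A short sketch of the persist/unpersist life cycle described above (editorial illustration, not part of this commit; assumes an existing SparkContext `sc` and a hypothetical log path):

import org.apache.spark.storage.StorageLevel

val parsed = sc.textFile("hdfs:///logs/part-*").map(_.split("\t"))
parsed.persist(StorageLevel.MEMORY_ONLY)               // equivalent to parsed.cache()

val errors = parsed.filter(_(0) == "ERROR").count()    // computes and caches the RDD
val warns  = parsed.filter(_(0) == "WARN").count()     // served from the cached blocks
parsed.unpersist()                                     // drop the blocks when done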
*/ + def getStorageLevel = storageLevel + + // Our dependencies and partitions will be gotten by calling subclass's methods below, and will + // be overwritten when we're checkpointed + private var dependencies_ : Seq[Dependency[_]] = null + @transient private var partitions_ : Array[Partition] = null + + /** An Option holding our checkpoint RDD, if we are checkpointed */ + private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + + /** + * Get the list of dependencies of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def dependencies: Seq[Dependency[_]] = { + checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { + if (dependencies_ == null) { + dependencies_ = getDependencies + } + dependencies_ + } + } + + /** + * Get the array of partitions of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def partitions: Array[Partition] = { + checkpointRDD.map(_.partitions).getOrElse { + if (partitions_ == null) { + partitions_ = getPartitions + } + partitions_ + } + } + + /** + * Get the preferred locations of a partition (as hostnames), taking into account whether the + * RDD is checkpointed. + */ + final def preferredLocations(split: Partition): Seq[String] = { + checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { + getPreferredLocations(split) + } + } + + /** + * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. + * This should ''not'' be called by users directly, but is available for implementors of custom + * subclasses of RDD. + */ + final def iterator(split: Partition, context: TaskContext): Iterator[T] = { + if (storageLevel != StorageLevel.NONE) { + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + } else { + computeOrReadCheckpoint(split, context) + } + } + + /** + * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. + */ + private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { + if (isCheckpointed) { + firstParent[T].iterator(split, context) + } else { + compute(split, context) + } + } + + // Transformations (return a new RDD) + + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ + def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) + + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ + def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = + new FlatMappedRDD(this, sc.clean(f)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ + def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numPartitions: Int): RDD[T] = + map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + + def distinct(): RDD[T] = distinct(partitions.size) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = { + if (shuffle) { + // include a shuffle step so that our upstream tasks are still distributed + new CoalescedRDD( + new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), + new HashPartitioner(numPartitions)), + numPartitions).keys + } else { + new CoalescedRDD(this, numPartitions) + } + } + + /** + * Return a sampled subset of this RDD. 
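A quick usage sketch of the transformations above, assuming an existing SparkContext `sc` (the data is made up):

    val words   = sc.parallelize(Seq("spark", "rdd", "spark", "scala"))
    val lengths = words.map(_.length)   // one-to-one, no shuffle
    val unique  = words.distinct()      // implemented with reduceByKey, so it shuffles
    val merged  = unique.coalesce(1)    // narrow coalesce; pass shuffle = true to redistribute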
+ */ + def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = + new SampledRDD(this, withReplacement, fraction, seed) + + def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { + var fraction = 0.0 + var total = 0 + val multiplier = 3.0 + val initialCount = this.count() + var maxSelected = 0 + + if (num < 0) { + throw new IllegalArgumentException("Negative number of elements requested") + } + + if (initialCount > Integer.MAX_VALUE - 1) { + maxSelected = Integer.MAX_VALUE - 1 + } else { + maxSelected = initialCount.toInt + } + + if (num > initialCount && !withReplacement) { + total = maxSelected + fraction = multiplier * (maxSelected + 1) / initialCount + } else { + fraction = multiplier * (num + 1) / initialCount + total = num + } + + val rand = new Random(seed) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for thei initial size + while (samples.length < total) { + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + } + + Utils.randomizeInPlace(samples, rand).take(total) + } + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ + def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ + def ++(other: RDD[T]): RDD[T] = this.union(other) + + /** + * Return an RDD created by coalescing all elements within each partition into an array. + */ + def glom(): RDD[Array[T]] = new GlommedRDD(this) + + /** + * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of + * elements (a, b) where a is in `this` and b is in `other`. + */ + def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + + /** + * Return an RDD of grouped items. + */ + def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] = + groupBy[K](f, defaultPartitioner(this)) + + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ + def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = + groupBy(f, new HashPartitioner(numPartitions)) + + /** + * Return an RDD of grouped items. + */ + def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { + val cleanF = sc.clean(f) + this.map(t => (cleanF(t), t)).groupByKey(p) + } + + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: String): RDD[String] = new PipedRDD(this, command) + + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: String, env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) + + + /** + * Return an RDD created by piping elements to a forked external process. + * The print behavior can be customized by providing two functions. + * + * @param command command to run in forked process. + * @param env environment variables to set. + * @param printPipeContext Before piping elements, this function is called as an oppotunity + * to pipe context data. Print line function (like out.println) will be + * passed as printPipeContext's parameter. 
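The sampling and grouping operations above, sketched on a made-up RDD[Int] (again assuming a SparkContext `sc`):

    val nums   = sc.parallelize(1 to 100)
    val tenth  = nums.sample(false, 0.1, 7)    // withReplacement = false, fraction = 0.1, seed = 7
    val five   = nums.takeSample(false, 5, 7)  // exactly 5 elements, returned to the driver as an Array
    val byTens = nums.groupBy(_ / 10)          // RDD[(Int, Seq[Int])]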
+ * @param printRDDElement Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. + * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} + * @return the result RDD + */ + def pipe( + command: Seq[String], + env: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + new PipedRDD(this, command, env, + if (printPipeContext ne null) sc.clean(printPipeContext) else null, + if (printRDDElement ne null) sc.clean(printRDDElement) else null) + + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ + def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. + */ + def mapPartitionsWithIndex[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. + */ + @deprecated("use mapPartitionsWithIndex", "0.7.0") + def mapPartitionsWithSplit[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Maps f over this RDD, where f takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def mapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false) + (f:(T, A) => U): RDD[U] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { + val a = constructA(index) + iter.map(t => f(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + } + + /** + * FlatMaps f over this RDD, where f takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def flatMapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false) + (f:(T, A) => Seq[U]): RDD[U] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { + val a = constructA(index) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + } + + /** + * Applies f to each element of this RDD, where f takes an additional parameter of type A. + * This additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. 
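A sketch of the partition-wise variants above on the same hypothetical `nums: RDD[Int]`; doing work once per partition rather than once per element is the usual reason to reach for them:

    val partialSums = nums.mapPartitions(iter => Iterator(iter.sum))   // one value per partition
    val tagged      = nums.mapPartitionsWithIndex((i, iter) => iter.map(x => (i, x)))   // pair each element with its partition index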
+ */ + def foreachWith[A: ClassTag](constructA: Int => A) + (f:(T, A) => Unit) { + def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { + val a = constructA(index) + iter.map(t => {f(t, a); t}) + } + (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + } + + /** + * Filters this RDD with p, where p takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def filterWith[A: ClassTag](constructA: Int => A) + (p:(T, A) => Boolean): RDD[T] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { + val a = constructA(index) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + } + + /** + * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, + * second element in each RDD, etc. Assumes that the two RDDs have the *same number of + * partitions* and the *same number of elements in each partition* (e.g. one was made through + * a map on the other). + */ + def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. + */ + def zipPartitions[B: ClassTag, V: ClassTag] + (rdd2: RDD[B]) + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + + def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] + (rdd2: RDD[B], rdd3: RDD[C]) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + + def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] + (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + + + // Actions (launch a job to return a value to the user program) + + /** + * Applies a function f to all elements of this RDD. + */ + def foreach(f: T => Unit) { + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) + } + + /** + * Applies a function f to each partition of this RDD. + */ + def foreachPartition(f: Iterator[T] => Unit) { + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) + } + + /** + * Return an array that contains all of the elements in this RDD. + */ + def collect(): Array[T] = { + val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) + Array.concat(results: _*) + } + + /** + * Return an array that contains all of the elements in this RDD. + */ + def toArray(): Array[T] = collect() + + /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: RDD[T]): RDD[T] = + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + + /** + * Return an RDD with the elements from `this` that are not in `other`. 
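The zip family in use (data and partition counts invented); as the comments above note, `zip` needs identical element counts per partition, while `zipPartitions` only needs matching partition counts:

    val a = sc.parallelize(1 to 4, 2)                  // 2 partitions of 2 elements
    val b = sc.parallelize(Seq("a", "b", "c", "d"), 2)
    val pairs  = a.zip(b)                              // RDD[(Int, String)]
    val zipped = a.zipPartitions(b) { (xs, ys) => xs.zip(ys) }   // same result, built partition by partition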
+ */ + def subtract(other: RDD[T], numPartitions: Int): RDD[T] = + subtract(other, new HashPartitioner(numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: RDD[T], p: Partitioner): RDD[T] = { + if (partitioner == Some(p)) { + // Our partitioner knows how to handle T (which, since we have a partitioner, is + // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples + val p2 = new Partitioner() { + override def numPartitions = p.numPartitions + override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + } + // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies + // anyway, and when calling .keys, will not have a partitioner set, even though + // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be + // partitioned by the right/real keys (e.g. p). + this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys + } else { + this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys + } + } + + /** + * Reduces the elements of this RDD using the specified commutative and associative binary operator. + */ + def reduce(f: (T, T) => T): T = { + val cleanF = sc.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + var jobResult: Option[T] = None + val mergeResult = (index: Int, taskResult: Option[T]) => { + if (taskResult != None) { + jobResult = jobResult match { + case Some(value) => Some(f(value, taskResult.get)) + case None => taskResult + } + } + } + sc.runJob(this, reducePartition, mergeResult) + // Get the final result out of our Option, or throw an exception if the RDD was empty + jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) + } + + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using a + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * modify t1 and return it as its result value to avoid object allocation; however, it should not + * modify t2. + */ + def fold(zeroValue: T)(op: (T, T) => T): T = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + val cleanOp = sc.clean(op) + val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) + val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) + sc.runJob(this, foldPartition, mergeResult) + jobResult + } + + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using + * given combine functions and a neutral "zero value". This function can return a different result + * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U + * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are + * allowed to modify and return their first argument instead of creating a new U to avoid memory + * allocation. 
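For example, three ways to total the hypothetical `nums: RDD[Int]`, with `aggregate` returning a different result type (a running (sum, count) pair) as the comment above describes:

    val sum1 = nums.reduce(_ + _)
    val sum2 = nums.fold(0)(_ + _)
    val (sum3, count) = nums.aggregate((0, 0))(
      (acc, x) => (acc._1 + x, acc._2 + 1),    // fold one element into the accumulator
      (a, b)   => (a._1 + b._1, a._2 + b._2))  // merge two per-partition accumulators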
+ */ + def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + val cleanSeqOp = sc.clean(seqOp) + val cleanCombOp = sc.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) + sc.runJob(this, aggregatePartition, mergeResult) + jobResult + } + + /** + * Return the number of elements in the RDD. + */ + def count(): Long = { + sc.runJob(this, (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }).sum + } + + /** + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. + */ + def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + } + val evaluator = new CountEvaluator(partitions.size, confidence) + sc.runApproximateJob(this, countElements, evaluator, timeout) + } + + /** + * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final + * combine step happens locally on the master, equivalent to running a single reduce task. + */ + def countByValue(): Map[T, Long] = { + if (elementClassTag.runtimeClass.isArray) { + throw new SparkException("countByValue() does not support arrays") + } + // TODO: This should perhaps be distributed by default. + def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { + val map = new OLMap[T] + while (iter.hasNext) { + val v = iter.next() + map.put(v, map.getLong(v) + 1L) + } + Iterator(map) + } + def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { + val iter = m2.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) + } + return m1 + } + val myResult = mapPartitions(countPartition).reduce(mergeMaps) + myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map + } + + /** + * (Experimental) Approximate version of countByValue(). + */ + def countByValueApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[Map[T, BoundedDouble]] = { + if (elementClassTag.runtimeClass.isArray) { + throw new SparkException("countByValueApprox() does not support arrays") + } + val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => + val map = new OLMap[T] + while (iter.hasNext) { + val v = iter.next() + map.put(v, map.getLong(v) + 1L) + } + map + } + val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + sc.runApproximateJob(this, countPartition, evaluator, timeout) + } + + /** + * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so + * it will be slow if a lot of partitions are required. In that case, use collect() to get the + * whole RDD instead. 
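The counting actions above, sketched on a hypothetical `words: RDD[String]` (the timeout value is illustrative):

    val n      = words.count()                   // Long
    val counts = words.countByValue()            // Map[String, Long], merged on the driver
    val approx = words.countApprox(timeout = 1000, confidence = 0.90)   // best available answer within the timeout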
+ */ + def take(num: Int): Array[T] = { + if (num == 0) { + return new Array[T](0) + } + val buf = new ArrayBuffer[T] + var p = 0 + while (buf.size < num && p < partitions.size) { + val left = num - buf.size + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true) + buf ++= res(0) + if (buf.size == num) + return buf.toArray + p += 1 + } + return buf.toArray + } + + /** + * Return the first element in this RDD. + */ + def first(): T = take(1) match { + case Array(t) => t + case _ => throw new UnsupportedOperationException("empty collection") + } + + /** + * Returns the top K elements from this RDD as defined by + * the specified implicit Ordering[T]. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { + mapPartitions { items => + val queue = new BoundedPriorityQueue[T](num) + queue ++= items + Iterator.single(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray.sorted(ord.reverse) + } + + /** + * Returns the first K elements from this RDD as defined by + * the specified implicit Ordering[T] and maintains the + * ordering. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse) + + /** + * Save this RDD as a text file, using string representations of elements. + */ + def saveAsTextFile(path: String) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) + } + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) + } + + /** + * Save this RDD as a SequenceFile of serialized objects. + */ + def saveAsObjectFile(path: String) { + this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) + .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) + .saveAsSequenceFile(path) + } + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: T => K): RDD[(K, T)] = { + map(x => (f(x), x)) + } + + /** A private method for tests, to look at the contents of each partition */ + private[spark] def collectPartitions(): Array[Array[T]] = { + sc.runJob(this, (iter: Iterator[T]) => iter.toArray) + } + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. 
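A few of the actions above in use (the output path is a placeholder, not taken from this patch):

    val firstFive = nums.take(5)          // scans partitions one by one
    val largest   = nums.top(3)           // uses the implicit Ordering[Int]
    val smallest  = nums.takeOrdered(3)   // same, but keeps ascending order
    nums.saveAsTextFile("/tmp/nums-out")  // writes one text part-file per partition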
+ */ + def checkpoint() { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { + checkpointData = Some(new RDDCheckpointData(this)) + checkpointData.get.markForCheckpoint() + } + } + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed: Boolean = { + checkpointData.map(_.isCheckpointed).getOrElse(false) + } + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile: Option[String] = { + checkpointData.flatMap(_.getCheckpointFile) + } + + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.formatSparkCallSite + + private[spark] def elementClassTag: ClassTag[T] = classTag[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassTag] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ + def context = sc + + // Avoid handling doCheckpoint multiple times to prevent excessive recursion + private var doCheckpointCalled = false + + /** + * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler + * after a job using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. + */ + private[spark] def doCheckpoint() { + if (!doCheckpointCalled) { + doCheckpointCalled = true + if (checkpointData.isDefined) { + checkpointData.get.doCheckpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } + } + } + + /** + * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) + * created from the checkpoint file, and forget its old dependencies and partitions. + */ + private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { + clearDependencies() + partitions_ = null + deps = null // Forget the constructor argument for dependencies too + } + + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own cleaning + * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. + */ + protected def clearDependencies() { + dependencies_ = null + } + + /** A description of this RDD and its recursive dependencies for debugging. 
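A sketch of the checkpointing lifecycle described above; the directory is a placeholder, and `setCheckpointDir` is assumed to be available on the SparkContext:

    sc.setCheckpointDir("/tmp/spark-checkpoints")   // assumed SparkContext method; directory is illustrative
    val data = sc.parallelize(1 to 1000).map(_ * 2)
    data.cache()         // recommended above: avoids recomputing when the checkpoint file is written
    data.checkpoint()    // only marks the RDD; nothing is written yet
    data.count()         // the first job triggers doCheckpoint()
    println(data.isCheckpointed)   // true once checkpointing has completed
    println(data.toDebugString)    // lineage now stops at the checkpoint file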
*/ + def toDebugString: String = { + def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { + Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++ + rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + } + debugString(this).mkString("\n") + } + + override def toString: String = "%s%s[%d] at %s".format( + Option(name).map(_ + " ").getOrElse(""), + getClass.getSimpleName, + id, + origin) + + def toJavaRDD() : JavaRDD[T] = { + new JavaRDD(this)(elementClassTag) + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala new file mode 100644 index 0000000000..3b56e45aa9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{Partition, SparkException, Logging} +import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} + +/** + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} + +/** + * This class contains all the information related to RDD checkpointing. Each instance of this class + * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, + * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations + * of the checkpointed RDD. + */ +private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) + extends Logging with Serializable { + + import CheckpointState._ + + // The checkpoint state of the associated RDD. + var cpState = Initialized + + // The file to which the associated RDD has been checkpointed to + @transient var cpFile: Option[String] = None + + // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. 
+ var cpRDD: Option[RDD[T]] = None + + // Mark the RDD for checkpointing + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } + } + + // Is the RDD already checkpointed + def isCheckpointed: Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } + } + + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile: Option[String] = { + RDDCheckpointData.synchronized { cpFile } + } + + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. + def doCheckpoint() { + // If it is marked for checkpointing AND checkpointing is not already in progress, + // then set it to be in progress, else return + RDDCheckpointData.synchronized { + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress + } else { + return + } + } + + // Create the output path for the checkpoint + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val fs = path.getFileSystem(new Configuration()) + if (!fs.mkdirs(path)) { + throw new SparkException("Failed to create checkpoint path " + path) + } + + // Save to file, and reload it as an RDD + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val newRDD = new CheckpointRDD[T](rdd.context, path.toString) + + // Change the dependencies and partitions of the RDD + RDDCheckpointData.synchronized { + cpFile = Some(path.toString) + cpRDD = Some(newRDD) + rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions + cpState = Checkpointed + RDDCheckpointData.clearTaskCaches() + logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) + } + } + + // Get preferred location of a split after checkpointing + def getPreferredLocations(split: Partition): Seq[String] = { + RDDCheckpointData.synchronized { + cpRDD.get.preferredLocations(split) + } + } + + def getPartitions: Array[Partition] = { + RDDCheckpointData.synchronized { + cpRDD.get.partitions + } + } + + def checkpointRDD: Option[RDD[T]] = { + RDDCheckpointData.synchronized { + cpRDD + } + } +} + +private[spark] object RDDCheckpointData { + def clearTaskCaches() { + ShuffleMapTask.clearCache() + ResultTask.clearCache() + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala new file mode 100644 index 0000000000..d433670cc2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag +import java.util.Random + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.{Partition, TaskContext} + +private[spark] +class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { + override val index: Int = prev.index +} + +class SampledRDD[T: ClassTag]( + prev: RDD[T], + withReplacement: Boolean, + frac: Double, + seed: Int) + extends RDD[T](prev) { + + override def getPartitions: Array[Partition] = { + val rg = new Random(seed) + firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) + } + + override def getPreferredLocations(split: Partition): Seq[String] = + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) + + override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[SampledRDDPartition] + if (withReplacement) { + // For large datasets, the expected number of occurrences of each element in a sample with + // replacement is Poisson(frac). We use that to get a count for each element. + val poisson = new Poisson(frac, new DRand(split.seed)) + firstParent[T].iterator(split.prev, context).flatMap { element => + val count = poisson.nextInt() + if (count == 0) { + Iterator.empty // Avoid object allocation when we return 0 items, which is quite often + } else { + Iterator.fill(count)(element) + } + } + } else { // Sampling without replacement + val rand = new Random(split.seed) + firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala new file mode 100644 index 0000000000..2d1bd5b481 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rdd + +import scala.reflect.{ ClassTag, classTag} + +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.SequenceFileOutputFormat +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.Writable + +import org.apache.spark.SparkContext._ +import org.apache.spark.Logging + +/** + * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, + * through an implicit conversion. Note that this can't be part of PairRDDFunctions because + * we need more implicit parameters to convert our keys and values to Writable. + * + * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. 
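A usage sketch for the implicit conversion this class provides (assuming a SparkContext `sc`; the import brings the conversion into scope and the path is a placeholder):

    import org.apache.spark.SparkContext._
    val pairs = sc.parallelize(Seq((1, "a"), (2, "b")))
    pairs.saveAsSequenceFile("/tmp/pairs-seq")   // Int and String are wrapped as IntWritable and Text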
+ */ +class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( + self: RDD[(K, V)]) + extends Logging + with Serializable { + + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { + val c = { + if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { + classTag[T].runtimeClass + } else { + // We get the type of the Writable class by looking at the apply method which converts + // from T to Writable. Since we have two apply methods we filter out the one which + // is not of the form "java.lang.Object apply(java.lang.Object)" + implicitly[T => Writable].getClass.getDeclaredMethods().filter( + m => m.getReturnType().toString != "class java.lang.Object" && + m.getName() == "apply")(0).getReturnType + + } + // TODO: use something like WritableConverter to avoid reflection + } + c.asInstanceOf[Class[_ <: Writable]] + } + + /** + * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key + * and value types. If the key or value are Writable, then we use their classes directly; + * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc, + * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported + * file system. + */ + def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { + def anyToWritable[U <% Writable](u: U): Writable = u + + val keyClass = getWritableClass[K] + val valueClass = getWritableClass[V] + val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass) + val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass) + + logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) + val format = classOf[SequenceFileOutputFormat[Writable, Writable]] + val jobConf = new JobConf(self.context.hadoopConfiguration) + if (!convertKey && !convertValue) { + self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) + } else if (!convertKey && convertValue) { + self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) + } else if (convertKey && !convertValue) { + self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) + } else if (convertKey && convertValue) { + self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala new file mode 100644 index 0000000000..9537152335 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext} + + +private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { + override val index = idx + override def hashCode(): Int = idx +} + +/** + * The resulting RDD from a shuffle (e.g. repartitioning of data). + * @param prev the parent RDD. + * @param part the partitioner used to partition the RDD + * @tparam K the key class. + * @tparam V the value class. + */ +class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( + @transient var prev: RDD[P], + part: Partitioner) + extends RDD[P](prev.context, Nil) { + + private var serializerClass: String = null + + def setSerializer(cls: String): ShuffledRDD[K, V, P] = { + serializerClass = cls + this + } + + override def getDependencies: Seq[Dependency[_]] = { + List(new ShuffleDependency(prev, part, serializerClass)) + } + + override val partitioner = Some(part) + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) + } + + override def compute(split: Partition, context: TaskContext): Iterator[P] = { + val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId + SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics, + SparkEnv.get.serializerManager.get(serializerClass)) + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala new file mode 100644 index 0000000000..85c512f3de --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.{HashMap => JHashMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.Partitioner +import org.apache.spark.Dependency +import org.apache.spark.TaskContext +import org.apache.spark.Partition +import org.apache.spark.SparkEnv +import org.apache.spark.ShuffleDependency +import org.apache.spark.OneToOneDependency + + +/** + * An optimized version of cogroup for set difference/subtraction. + * + * It is possible to implement this operation with just `cogroup`, but + * that is less efficient because all of the entries from `rdd2`, for + * both matching and non-matching values in `rdd1`, are kept in the + * JHashMap until the end. 
+ * + * With this implementation, only the entries from `rdd1` are kept in-memory, + * and the entries from `rdd2` are essentially streamed, as we only need to + * touch each once to decide if the value needs to be removed. + * + * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as + * you can use `rdd1`'s partitioner/partition size and not worry about running + * out of memory because of the size of `rdd2`. + */ +private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( + @transient var rdd1: RDD[_ <: Product2[K, V]], + @transient var rdd2: RDD[_ <: Product2[K, W]], + part: Partitioner) + extends RDD[(K, V)](rdd1.context, Nil) { + + private var serializerClass: String = null + + def setSerializer(cls: String): SubtractedRDD[K, V, W] = { + serializerClass = cls + this + } + + override def getDependencies: Seq[Dependency[_]] = { + Seq(rdd1, rdd2).map { rdd => + if (rdd.partitioner == Some(part)) { + logDebug("Adding one-to-one dependency with " + rdd) + new OneToOneDependency(rdd) + } else { + logDebug("Adding shuffle dependency with " + rdd) + new ShuffleDependency(rdd, part, serializerClass) + } + } + } + + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](part.numPartitions) + for (i <- 0 until array.size) { + // Each CoGroupPartition will depend on rdd1 and rdd2 + array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => + dependencies(j) match { + case s: ShuffleDependency[_, _] => + new ShuffleCoGroupSplitDep(s.shuffleId) + case _ => + new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + } + }.toArray) + } + array + } + + override val partitioner = Some(part) + + override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { + val partition = p.asInstanceOf[CoGroupPartition] + val serializer = SparkEnv.get.serializerManager.get(serializerClass) + val map = new JHashMap[K, ArrayBuffer[V]] + def getSeq(k: K): ArrayBuffer[V] = { + val seq = map.get(k) + if (seq != null) { + seq + } else { + val seq = new ArrayBuffer[V]() + map.put(k, seq) + seq + } + } + def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { + rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) + } + case ShuffleCoGroupSplitDep(shuffleId) => { + val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, + context.taskMetrics, serializer) + iter.foreach(op) + } + } + // the first dep is rdd1; add all values to the map + integrate(partition.deps(0), t => getSeq(t._1) += t._2) + // the second dep is rdd2; remove all of its keys + integrate(partition.deps(1), t => map.remove(t._1)) + map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala new file mode 100644 index 0000000000..08a41ac558 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
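Seen through the user-facing `subtract` defined earlier, the optimization above means only the left side is buffered (data invented, assuming a SparkContext `sc`):

    val left  = sc.parallelize(Seq(1, 2, 3, 4))
    val right = sc.parallelize(Seq(3, 4, 5))
    left.subtract(right).collect()   // Array(1, 2) in some order; only `left`'s entries are held in memory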
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext} + +import java.io.{ObjectOutputStream, IOException} + +private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Partition { + + var split: Partition = rdd.partitions(splitIndex) + + def iterator(context: TaskContext) = rdd.iterator(split, context) + + def preferredLocations() = rdd.preferredLocations(split) + + override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + split = rdd.partitions(splitIndex) + oos.defaultWriteObject() + } +} + +class UnionRDD[T: ClassTag]( + sc: SparkContext, + @transient var rdds: Seq[RDD[T]]) + extends RDD[T](sc, Nil) { // Nil since we implement getDependencies + + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](rdds.map(_.partitions.size).sum) + var pos = 0 + for (rdd <- rdds; split <- rdd.partitions) { + array(pos) = new UnionPartition(pos, rdd, split.index) + pos += 1 + } + array + } + + override def getDependencies: Seq[Dependency[_]] = { + val deps = new ArrayBuffer[Dependency[_]] + var pos = 0 + for (rdd <- rdds) { + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) + pos += rdd.partitions.size + } + deps + } + + override def compute(s: Partition, context: TaskContext): Iterator[T] = + s.asInstanceOf[UnionPartition[T]].iterator(context) + + override def getPreferredLocations(s: Partition): Seq[String] = + s.asInstanceOf[UnionPartition[T]].preferredLocations() +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala new file mode 100644 index 0000000000..e02c17bf45 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} +import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag + +private[spark] class ZippedPartitionsPartition( + idx: Int, + @transient rdds: Seq[RDD[_]]) + extends Partition { + + override val index: Int = idx + var partitionValues = rdds.map(rdd => rdd.partitions(idx)) + def partitions = partitionValues + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + partitionValues = rdds.map(rdd => rdd.partitions(idx)) + oos.defaultWriteObject() + } +} + +abstract class ZippedPartitionsBaseRDD[V: ClassTag]( + sc: SparkContext, + var rdds: Seq[RDD[_]]) + extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + + override def getPartitions: Array[Partition] = { + val sizes = rdds.map(x => x.partitions.size) + if (!sizes.forall(x => x == sizes(0))) { + throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + } + val array = new Array[Partition](sizes(0)) + for (i <- 0 until sizes(0)) { + array(i) = new ZippedPartitionsPartition(i, rdds) + } + array + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions + val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } + // Check whether there are any hosts that match all RDDs; otherwise return the union + val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) + if (!exactMatchLocations.isEmpty) { + exactMatchLocations + } else { + prefs.flatten.distinct + } + } + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } +} + +class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( + sc: SparkContext, + f: (Iterator[A], Iterator[B]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B]) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} + +class ZippedPartitionsRDD3 + [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C]) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + } +} + +class ZippedPartitionsRDD4 + [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C], + var rdd4: RDD[D]) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = 
s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context), + rdd4.iterator(partitions(3), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + rdd4 = null + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala new file mode 100644 index 0000000000..fb5b070c18 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} + +import java.io.{ObjectOutputStream, IOException} + +import scala.reflect.ClassTag + +private[spark] class ZippedPartition[T: ClassTag, U: ClassTag]( + idx: Int, + @transient rdd1: RDD[T], + @transient rdd2: RDD[U] + ) extends Partition { + + var partition1 = rdd1.partitions(idx) + var partition2 = rdd2.partitions(idx) + override val index: Int = idx + + def partitions = (partition1, partition2) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent partition at the time of task serialization + partition1 = rdd1.partitions(idx) + partition2 = rdd2.partitions(idx) + oos.defaultWriteObject() + } +} + +class ZippedRDD[T: ClassTag, U: ClassTag]( + sc: SparkContext, + var rdd1: RDD[T], + var rdd2: RDD[U]) + extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) { + + override def getPartitions: Array[Partition] = { + if (rdd1.partitions.size != rdd2.partitions.size) { + throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + } + val array = new Array[Partition](rdd1.partitions.size) + for (i <- 0 until rdd1.partitions.size) { + array(i) = new ZippedPartition(i, rdd1, rdd2) + } + array + } + + override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = { + val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions + rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context)) + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions + val pref1 = rdd1.preferredLocations(partition1) + val pref2 = rdd2.preferredLocations(partition2) + // Check whether there are any hosts that match both RDDs; otherwise return the union + val exactMatchLocations = pref1.intersect(pref2) + if (!exactMatchLocations.isEmpty) { + exactMatchLocations + } else { + (pref1 ++ pref2).distinct + } + } + + override 
def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala new file mode 100644 index 0000000000..0b04607d01 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.TaskContext + +import java.util.Properties + +/** + * Tracks information about an active job in the DAGScheduler. + */ +private[spark] class ActiveJob( + val jobId: Int, + val finalStage: Stage, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], + val callSite: String, + val listener: JobListener, + val properties: Properties) { + + val numPartitions = partitions.length + val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala new file mode 100644 index 0000000000..854dbfee09 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -0,0 +1,851 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.io.NotSerializableException +import java.util.Properties +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.storage.{BlockManager, BlockManagerMaster} +import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} + +/** + * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of + * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a + * minimal schedule to run the job. It then submits stages as TaskSets to an underlying + * TaskScheduler implementation that runs them on the cluster. + * + * In addition to coming up with a DAG of stages, this class also determines the preferred + * locations to run each task on, based on the current cache status, and passes these to the + * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being + * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are + * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task + * a small number of times before cancelling the whole stage. + * + * THREADING: This class runs all its logic in a single thread executing the run() method, to which + * events are submitted using a synchonized queue (eventQueue). The public API methods, such as + * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods + * should be private. + */ +private[spark] +class DAGScheduler( + taskSched: TaskScheduler, + mapOutputTracker: MapOutputTracker, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv) + extends TaskSchedulerListener with Logging { + + def this(taskSched: TaskScheduler) { + this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) + } + taskSched.setListener(this) + + // Called by TaskScheduler to report task's starting. + override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + eventQueue.put(BeginEvent(task, taskInfo)) + } + + // Called by TaskScheduler to report task completions or failures. + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) { + eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + } + + // Called by TaskScheduler when an executor fails. + override def executorLost(execId: String) { + eventQueue.put(ExecutorLost(execId)) + } + + // Called by TaskScheduler when a host is added + override def executorGained(execId: String, host: String) { + eventQueue.put(ExecutorGained(execId, host)) + } + + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. 
+ override def taskSetFailed(taskSet: TaskSet, reason: String) { + eventQueue.put(TaskSetFailed(taskSet, reason)) + } + + // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; + // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one + // as more failure events come in + val RESUBMIT_TIMEOUT = 50L + + // The time, in millis, to wake up between polls of the completion queue in order to potentially + // resubmit failed stages + val POLL_TIMEOUT = 10L + + private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] + + val nextJobId = new AtomicInteger(0) + + val nextStageId = new AtomicInteger(0) + + val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + + private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] + + private val listenerBus = new SparkListenerBus() + + // Contains the locations that each RDD's partitions are cached on + private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] + + // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with + // every task. When we detect a node failing, we note the current epoch number and failed + // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results. + // + // TODO: Garbage collect information about failure epochs when we know there are no more + // stray messages to detect. + val failedEpoch = new HashMap[String, Long] + + val idToActiveJob = new HashMap[Int, ActiveJob] + + val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done + val running = new HashSet[Stage] // Stages we are running right now + val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits + + val activeJobs = new HashSet[ActiveJob] + val resultStageToJob = new HashMap[Stage, ActiveJob] + + val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + + // Start a thread to run the DAGScheduler event loop + def start() { + new Thread("DAGScheduler") { + setDaemon(true) + override def run() { + DAGScheduler.this.run() + } + }.start() + } + + def addSparkListener(listener: SparkListener) { + listenerBus.addListener(listener) + } + + private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { + if (!cacheLocs.contains(rdd.id)) { + val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) + cacheLocs(rdd.id) = blockIds.map { id => + locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) + } + } + cacheLocs(rdd.id) + } + + private def clearCacheLocs() { + cacheLocs.clear() + } + + /** + * Get or create a shuffle map stage for the given shuffle dependency's map side. + * The jobId value passed in will be used if the stage doesn't already exist with + * a lower jobId (jobId always increases across jobs.) 
+ */ + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { + shuffleToMapStage.get(shuffleDep.shuffleId) match { + case Some(stage) => stage + case None => + val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId) + shuffleToMapStage(shuffleDep.shuffleId) = stage + stage + } + } + + /** + * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or + * as a result stage for the final RDD used directly in an action. The stage will also be + * associated with the provided jobId. + */ + private def newStage( + rdd: RDD[_], + shuffleDep: Option[ShuffleDependency[_,_]], + jobId: Int, + callSite: Option[String] = None) + : Stage = + { + if (shuffleDep != None) { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of partitions is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) + } + val id = nextStageId.getAndIncrement() + val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) + stageIdToStage(id) = stage + stageToInfos(stage) = StageInfo(stage) + stage + } + + /** + * Get or create the list of parent stages for a given RDD. The stages will be assigned the + * provided jobId if they haven't already been created with a lower jobId. + */ + private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + val parents = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + // Kind of ugly: need to register RDDs with the cache here since + // we can't do it in its constructor because # of partitions is unknown + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_] => + parents += getShuffleMapStage(shufDep, jobId) + case _ => + visit(dep.rdd) + } + } + } + } + visit(rdd) + parents.toList + } + + private def getMissingParentStages(stage: Stage): List[Stage] = { + val missing = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + def visit(rdd: RDD[_]) { + if (!visited(rdd)) { + visited += rdd + if (getCacheLocs(rdd).contains(Nil)) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_] => + val mapStage = getShuffleMapStage(shufDep, stage.jobId) + if (!mapStage.isAvailable) { + missing += mapStage + } + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) + } + } + } + } + } + visit(stage.rdd) + missing.toList + } + + /** + * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a + * JobWaiter whose getResult() method will return the result of the job when it is complete. + * + * The job is assumed to have at least one partition; zero partition jobs should be handled + * without a JobSubmitted event. 
+ */ + private[scheduler] def prepareJob[T, U: ClassTag]( + finalRdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: String, + allowLocal: Boolean, + resultHandler: (Int, U) => Unit, + properties: Properties = null) + : (JobSubmitted, JobWaiter[U]) = + { + assert(partitions.size > 0) + val waiter = new JobWaiter(partitions.size, resultHandler) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, + properties) + (toSubmit, waiter) + } + + def runJob[T, U: ClassTag]( + finalRdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: String, + allowLocal: Boolean, + resultHandler: (Int, U) => Unit, + properties: Properties = null) + { + if (partitions.size == 0) { + return + } + + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = finalRdd.partitions.length + partitions.find(p => p >= maxPartitions).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". " + + "Total number of partitions: " + maxPartitions) + } + + val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( + finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) + eventQueue.put(toSubmit) + waiter.awaitResult() match { + case JobSucceeded => {} + case JobFailed(exception: Exception, _) => + logInfo("Failed to run " + callSite) + throw exception + } + } + + def runApproximateJob[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + callSite: String, + timeout: Long, + properties: Properties = null) + : PartialResult[R] = + { + val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val partitions = (0 until rdd.partitions.size).toArray + eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + listener.awaitResult() // Will throw an exception if the job fails + } + + /** + * Process one event retrieved from the event queue. + * Returns true if we should stop the event loop. + */ + private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { + event match { + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => + val jobId = nextJobId.getAndIncrement() + val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) + val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + + " output partitions (allowLocal=" + allowLocal + ")") + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + // Compute very short actions like first() or take() with no parent stages locally. 
+ runLocally(job) + } else { + listenerBus.post(SparkListenerJobStart(job, properties)) + idToActiveJob(jobId) = job + activeJobs += job + resultStageToJob(finalStage) = job + submitStage(finalStage) + } + + case ExecutorGained(execId, host) => + handleExecutorGained(execId, host) + + case ExecutorLost(execId) => + handleExecutorLost(execId) + + case begin: BeginEvent => + listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + + case completion: CompletionEvent => + listenerBus.post(SparkListenerTaskEnd( + completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) + handleTaskCompletion(completion) + + case TaskSetFailed(taskSet, reason) => + abortStage(stageIdToStage(taskSet.stageId), reason) + + case StopDAGScheduler => + // Cancel any active jobs + for (job <- activeJobs) { + val error = new SparkException("Job cancelled because SparkContext was shut down") + job.listener.jobFailed(error) + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, None))) + } + return true + } + false + } + + /** + * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + * the last fetch failure. + */ + private[scheduler] def resubmitFailedStages() { + logInfo("Resubmitting failed stages") + clearCacheLocs() + val failed2 = failed.toArray + failed.clear() + for (stage <- failed2.sortBy(_.jobId)) { + submitStage(stage) + } + } + + /** + * Check for waiting or failed stages which are now eligible for resubmission. + * Ordinarily run on every iteration of the event loop. + */ + private[scheduler] def submitWaitingStages() { + // TODO: We might want to run this less often, when we are sure that something has become + // runnable that wasn't before. + logTrace("Checking for newly runnable parent stages") + logTrace("running: " + running) + logTrace("waiting: " + waiting) + logTrace("failed: " + failed) + val waiting2 = waiting.toArray + waiting.clear() + for (stage <- waiting2.sortBy(_.jobId)) { + submitStage(stage) + } + } + + + /** + * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure + * events and responds by launching tasks. This runs in a dedicated thread and receives events + * via the eventQueue. + */ + private def run() { + SparkEnv.set(env) + + while (true) { + val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) + if (event != null) { + logDebug("Got event of type " + event.getClass.getName) + } + this.synchronized { // needed in case other threads makes calls into methods of this class + if (event != null) { + if (processEvent(event)) { + return + } + } + + val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability + // Periodically resubmit failed stages if some map output fetches have failed and we have + // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, + // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at + // the same time, so we want to make sure we've identified all the reduce tasks that depend + // on the failed node. + if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { + resubmitFailedStages() + } else { + submitWaitingStages() + } + } + } + } + + /** + * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. + * We run the operation in a separate thread just in case it takes a bunch of time, so that we + * don't block the DAGScheduler event loop or other concurrent jobs. 
+ */ + protected def runLocally(job: ActiveJob) { + logInfo("Computing the requested partition locally") + new Thread("Local computation of job " + job.jobId) { + override def run() { + runLocallyWithinThread(job) + } + }.start() + } + + // Broken out for easier testing in DAGSchedulerSuite. + protected def runLocallyWithinThread(job: ActiveJob) { + try { + SparkEnv.set(env) + val rdd = job.finalStage.rdd + val split = rdd.partitions(job.partitions(0)) + val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) + try { + val result = job.func(taskContext, rdd.iterator(split, taskContext)) + job.listener.taskSucceeded(0, result) + } finally { + taskContext.executeOnCompleteCallbacks() + } + } catch { + case e: Exception => + job.listener.jobFailed(e) + } + } + + /** Submits stage, but first recursively submits any missing parents. */ + private def submitStage(stage: Stage) { + logDebug("submitStage(" + stage + ")") + if (!waiting(stage) && !running(stage) && !failed(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing == Nil) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage + } + } + } + + /** Called when stage's parents are available and we can now do its task. */ + private def submitMissingTasks(stage: Stage) { + logDebug("submitMissingTasks(" + stage + ")") + // Get our pending tasks and remember them in our pendingTasks entry + val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) + myPending.clear() + var tasks = ArrayBuffer[Task[_]]() + if (stage.isShuffleMap) { + for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + } + } else { + // This is a final stage; figure out its job's missing partitions + val job = resultStageToJob(stage) + for (id <- 0 until job.numPartitions if !job.finished(id)) { + val partition = job.partitions(id) + val locs = getPreferredLocs(stage.rdd, partition) + tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + } + } + // must be run listener before possible NotSerializableException + // should be "StageSubmitted" first and then "JobEnded" + val properties = idToActiveJob(stage.jobId).properties + listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties)) + + if (tasks.size > 0) { + // Preemptively serialize a task to make sure it can be serialized. We are catching this + // exception here because it would be fairly hard to catch the non-serializable exception + // down the road, where we have several different implementations for local scheduler and + // cluster schedulers. 
+ try { + SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + } catch { + case e: NotSerializableException => + abortStage(stage, e.toString) + running -= stage + return + } + + logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") + myPending ++= tasks + logDebug("New pending tasks: " + myPending) + taskSched.submitTasks( + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) + if (!stage.submissionTime.isDefined) { + stage.submissionTime = Some(System.currentTimeMillis()) + } + } else { + logDebug("Stage " + stage + " is actually done; %b %d %d".format( + stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) + running -= stage + } + } + + /** + * Responds to a task finishing. This is called inside the event loop so it assumes that it can + * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. + */ + private def handleTaskCompletion(event: CompletionEvent) { + val task = event.task + val stage = stageIdToStage(task.stageId) + + def markStageAsFinished(stage: Stage) = { + val serviceTime = stage.submissionTime match { + case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) + case _ => "Unkown" + } + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.completionTime = Some(System.currentTimeMillis) + listenerBus.post(StageCompleted(stageToInfos(stage))) + running -= stage + } + event.reason match { + case Success => + logInfo("Completed " + task) + if (event.accumUpdates != null) { + Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted + } + pendingTasks(stage) -= task + stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics + task match { + case rt: ResultTask[_, _] => + resultStageToJob.get(stage) match { + case Some(job) => + if (!job.finished(rt.outputId)) { + job.finished(rt.outputId) = true + job.numFinished += 1 + // If the whole job has finished, remove it + if (job.numFinished == job.numPartitions) { + idToActiveJob -= stage.jobId + activeJobs -= job + resultStageToJob -= stage + markStageAsFinished(stage) + listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) + } + job.listener.taskSucceeded(rt.outputId, event.result) + } + case None => + logInfo("Ignoring result from " + rt + " because its job has finished") + } + + case smt: ShuffleMapTask => + val status = event.result.asInstanceOf[MapStatus] + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) + } else { + stage.addOutputLoc(smt.partition, status) + } + if (running.contains(stage) && pendingTasks(stage).isEmpty) { + markStageAsFinished(stage) + logInfo("looking for newly runnable stages") + logInfo("running: " + running) + logInfo("waiting: " + waiting) + logInfo("failed: " + failed) + if (stage.shuffleDep != None) { + // We supply true to increment the epoch number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // epoch incremented to refetch them. + // TODO: Only increment the epoch number if this is not the first time + // we registered these map outputs. 
+ mapOutputTracker.registerMapOutputs( + stage.shuffleDep.get.shuffleId, + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = true) + } + clearCacheLocs() + if (stage.outputLocs.count(_ == Nil) != 0) { + // Some tasks had failed; let's resubmit this stage + // TODO: Lower-level scheduler should also deal with this + logInfo("Resubmitting " + stage + " (" + stage.name + + ") because some of its tasks had failed: " + + stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) + submitStage(stage) + } else { + val newlyRunnable = new ArrayBuffer[Stage] + for (stage <- waiting) { + logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) + } + for (stage <- waiting if getMissingParentStages(stage) == Nil) { + newlyRunnable += stage + } + waiting --= newlyRunnable + running ++= newlyRunnable + for (stage <- newlyRunnable.sortBy(_.id)) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") + submitMissingTasks(stage) + } + } + } + } + + case Resubmitted => + logInfo("Resubmitted " + task + ", so marking it as still running") + pendingTasks(stage) += task + + case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + // Mark the stage that the reducer was in as unrunnable + val failedStage = stageIdToStage(task.stageId) + running -= failedStage + failed += failedStage + // TODO: Cancel running tasks in the stage + logInfo("Marking " + failedStage + " (" + failedStage.name + + ") for resubmision due to a fetch failure") + // Mark the map whose fetch failed as broken in the map stage + val mapStage = shuffleToMapStage(shuffleId) + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name + + "); marking it for resubmission") + failed += mapStage + // Remember that a fetch failed now; this is used to resubmit the broken + // stages later, after a small wait (to give other tasks the chance to fail) + lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + } + + case ExceptionFailure(className, description, stackTrace, metrics) => + // Do nothing here, left up to the TaskScheduler to decide how to handle user failures + + case other => + // Unrecognized failure - abort all jobs depending on this stage + abortStage(stageIdToStage(task.stageId), task + " failed: " + other) + } + } + + /** + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. + * + * Optionally the epoch during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. 
+ */ + private def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { + val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) + if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { + failedEpoch(execId) = currentEpoch + logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) + blockManagerMaster.removeExecutor(execId) + // TODO: This will be really slow if we keep accumulating shuffle map stages + for ((shuffleId, stage) <- shuffleToMapStage) { + stage.removeOutputsOnExecutor(execId) + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementEpoch() + } + clearCacheLocs() + } else { + logDebug("Additional executor lost message for " + execId + + "(epoch " + currentEpoch + ")") + } + } + + private def handleExecutorGained(execId: String, host: String) { + // remove from failedEpoch(execId) ? + if (failedEpoch.contains(execId)) { + logInfo("Host gained which was in lost list earlier: " + host) + failedEpoch -= execId + } + } + + /** + * Aborts all jobs depending on a particular Stage. This is called in response to a task set + * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + */ + private def abortStage(failedStage: Stage, reason: String) { + val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq + failedStage.completionTime = Some(System.currentTimeMillis()) + for (resultStage <- dependentStages) { + val job = resultStageToJob(resultStage) + val error = new SparkException("Job failed: " + reason) + job.listener.jobFailed(error) + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) + idToActiveJob -= resultStage.jobId + activeJobs -= job + resultStageToJob -= resultStage + } + if (dependentStages.isEmpty) { + logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") + } + } + + /** + * Return true if one of stage's ancestors is target. + */ + private def stageDependsOn(stage: Stage, target: Stage): Boolean = { + if (stage == target) { + return true + } + val visitedRdds = new HashSet[RDD[_]] + val visitedStages = new HashSet[Stage] + def visit(rdd: RDD[_]) { + if (!visitedRdds(rdd)) { + visitedRdds += rdd + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_] => + val mapStage = getShuffleMapStage(shufDep, stage.jobId) + if (!mapStage.isAvailable) { + visitedStages += mapStage + visit(mapStage.rdd) + } // Otherwise there's no need to follow the dependency back + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) + } + } + } + } + visit(stage.rdd) + visitedRdds.contains(target.rdd) + } + + /** + * Synchronized method that might be called from other threads. 
+ * @param rdd whose partitions are to be looked at + * @param partition to lookup locality information for + * @return list of machines that are preferred by the partition + */ + private[spark] + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + // If the partition is cached, return the cache locations + val cached = getCacheLocs(rdd)(partition) + if (!cached.isEmpty) { + return cached + } + // If the RDD has some placement preferences (as is the case for input RDDs), get those + val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList + if (!rddPrefs.isEmpty) { + return rddPrefs.map(host => TaskLocation(host)) + } + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. + rdd.dependencies.foreach(_ match { + case n: NarrowDependency[_] => + for (inPart <- n.getParents(partition)) { + val locs = getPreferredLocs(n.rdd, inPart) + if (locs != Nil) + return locs + } + case _ => + }) + Nil + } + + private def cleanup(cleanupTime: Long) { + var sizeBefore = stageIdToStage.size + stageIdToStage.clearOldValues(cleanupTime) + logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) + + sizeBefore = shuffleToMapStage.size + shuffleToMapStage.clearOldValues(cleanupTime) + logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) + + sizeBefore = pendingTasks.size + pendingTasks.clearOldValues(cleanupTime) + logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) + + sizeBefore = stageToInfos.size + stageToInfos.clearOldValues(cleanupTime) + logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) + } + + def stop() { + eventQueue.put(StopDAGScheduler) + metadataCleaner.cancel() + taskSched.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala new file mode 100644 index 0000000000..0d99670648 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties + +import org.apache.spark.scheduler.cluster.TaskInfo +import scala.collection.mutable.Map + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.executor.TaskMetrics + +/** + * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue + * architecture where any thread can post an event (e.g. 
a task finishing or a new job being + * submitted) but there is a single "logic" thread that reads these events and takes decisions. + * This greatly simplifies synchronization. + */ +private[spark] sealed trait DAGSchedulerEvent + +private[spark] case class JobSubmitted( + finalRDD: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + allowLocal: Boolean, + callSite: String, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + +private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + +private[spark] case class CompletionEvent( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) + extends DAGSchedulerEvent + +private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent + +private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent + +private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent + +private[spark] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala new file mode 100644 index 0000000000..22e3723ac8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import com.codahale.metrics.{Gauge,MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { + val metricRegistry = new MetricRegistry() + val sourceName = "DAGScheduler" + + metricRegistry.register(MetricRegistry.name("stage", "failedStages", "number"), new Gauge[Int] { + override def getValue: Int = dagScheduler.failed.size + }) + + metricRegistry.register(MetricRegistry.name("stage", "runningStages", "number"), new Gauge[Int] { + override def getValue: Int = dagScheduler.running.size + }) + + metricRegistry.register(MetricRegistry.name("stage", "waitingStages", "number"), new Gauge[Int] { + override def getValue: Int = dagScheduler.waiting.size + }) + + metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] { + override def getValue: Int = dagScheduler.nextJobId.get() + }) + + metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] { + override def getValue: Int = dagScheduler.activeJobs.size + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala new file mode 100644 index 0000000000..370ccd183c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.{Logging, SparkEnv} +import scala.collection.immutable.Set +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.conf.Configuration +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ + + +/** + * Parses and holds information about inputFormat (and files) specified as a parameter. + */ +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], + val path: String) extends Logging { + + var mapreduceInputFormat: Boolean = false + var mapredInputFormat: Boolean = false + + validate() + + override def toString(): String = { + "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + path.hashCode + hashCode + } + + // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path + // .. which is fine, this is best case effort to remove duplicates - right ? 
+ override def equals(other: Any): Boolean = other match { + case that: InputFormatInfo => { + // not checking config - that should be fine, right ? + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path + } + case _ => false + } + + private def validate() { + logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) + + try { + if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapreduce package") + mapreduceInputFormat = true + } + else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapred package") + mapredInputFormat = true + } + else { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + + " is NOT a supported input format ? does not implement either of the supported hadoop api's") + } + } + catch { + case e: ClassNotFoundException => { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) + } + } + } + + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { + val env = SparkEnv.get + val conf = new JobConf(configuration) + env.hadoop.addCredentials(conf) + FileInputFormat.setInputPaths(conf, path) + + val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ + org.apache.hadoop.mapreduce.InputFormat[_, _]] + val job = new Job(conf) + + val retval = new ArrayBuffer[SplitInfo]() + val list = instance.getSplits(job) + for (split <- list) { + retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) + } + + return retval.toSet + } + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { + val env = SparkEnv.get + val jobConf = new JobConf(configuration) + env.hadoop.addCredentials(jobConf) + FileInputFormat.setInputPaths(jobConf, path) + + val instance: org.apache.hadoop.mapred.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ + org.apache.hadoop.mapred.InputFormat[_, _]] + + val retval = new ArrayBuffer[SplitInfo]() + instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( + elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) + ) + + return retval.toSet + } + + private def findPreferredLocations(): Set[SplitInfo] = { + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + ", inputFormatClazz : " + inputFormatClazz) + if (mapreduceInputFormat) { + return prefLocsFromMapreduceInputFormat() + } + else { + assert(mapredInputFormat) + return prefLocsFromMapredInputFormat() + } + } +} + + + + +object InputFormatInfo { + /** + Computes the preferred locations based on input(s) and returned a location to block map. + Typical use of this method for allocation would follow some algo like this + (which is what we currently do in YARN branch) : + a) For each host, count number of splits hosted on that host. + b) Decrement the currently allocated containers on that host. + c) Compute rack info for each host and update rack -> count map based on (b). 
+ d) Allocate nodes based on (c) + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single + (or small set of) hosts go down. + + go to (a) until required nodes are allocated. + + If a node 'dies', follow same procedure. + + PS: I know the wording here is weird, hopefully it makes some sense ! + */ + def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { + + val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] + for (inputSplit <- formats) { + val splits = inputSplit.findPreferredLocations() + + for (split <- splits){ + val location = split.hostLocation + val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) + set += split + } + } + + nodeToSplit + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala new file mode 100644 index 0000000000..50c2b9acd6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +/** + * Interface used to listen for job completion or failure events after submitting a job to the + * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole + * job fails (and no further taskSucceeded events will happen). + */ +private[spark] trait JobListener { + def taskSucceeded(index: Int, result: Any) + def jobFailed(exception: Exception) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala new file mode 100644 index 0000000000..c8b78bf00a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -0,0 +1,293 @@ +/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.cluster.TaskInfo
+
+// Used to record runtime information for each job, including the RDD graph,
+// tasks' start/stop, shuffle information, and information supplied from outside
+
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark"
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+ def this() = this(String.valueOf(System.currentTimeMillis()))
+
+ def getLogDir = logDir
+ def getJobIDtoPrintWriter = jobIDToPrintWriter
+ def getStageIDToJobID = stageIDToJobID
+ def getJobIDToStages = jobIDToStages
+ def getEventQueue = eventQueue
+
+ // Create a folder for log files; the folder's name is the creation time of the JobLogger
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (!dir.mkdirs()) {
+ logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ // Create a log file for one job; the file name is the jobID
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ // Close the job's log file and remove its stage mappings from stageIDToJobID
+ protected def closeLogWriter(jobID: Int) =
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+
+ // Write log information to the log file; the withTime parameter controls whether to record
+ // a time stamp with the information
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " + info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
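For illustration (hypothetical timestamp), a call such as jobLogInfo(0, "JOB_ID=0 STATUS=STARTED") would append a line of the form:

  2013/09/01 12:00:00: JOB_ID=0 STATUS=STARTED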
+
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach { dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ case _ => rddList ++= getRddsInStage(dep.rdd)
+ }
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach { _.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
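As a sketch with hypothetical IDs, a stage whose lineage contains RDDs 5 and 4 and whose only parent is map stage 2 (shuffle 0) would be recorded by recordStageDep as:

  STAGE_ID=3 RDD_DEP=(5,4) STAGE_DEP=(2,0)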
+
+ // Generate indents and convert to String
+ protected def indentString(indent: Int) = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ protected def getRddName(rdd: RDD[_]) = {
+ var rddName = rdd.getClass.getName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach { dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+ }
+
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+ var stageInfo: String = ""
+ if (stage.isShuffleMap) {
+ stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
+ stage.shuffleDep.get.shuffleId
+ } else {
+ stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+ } else
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+
+ // Record task metrics into job log files
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics =
+ taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics =
+ taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
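For illustration, a successful ResultTask with no shuffle metrics (hypothetical IDs, times, and host) would be logged by recordTaskMetrics, via stageLogInfo, as a single line along these lines:

  2013/09/01 12:00:01: TASK_TYPE=RESULT_TASK STATUS=SUCCESS TID=12 STAGE_ID=1 START_TIME=1378036800000 FINISH_TIME=1378036800420 EXECUTOR_ID=exec-1 HOST=worker-1 EXECUTOR_RUN_TIME=350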
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(
+ stageSubmitted.stage.id,
+ "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.id, stageSubmitted.taskSize))
+ }
+
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(
+ stageCompleted.stageInfo.stage.id,
+ "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
+
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.stripSuffix("_").toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if (properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage)
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
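With the listener complete, a minimal sketch of hooking it into an application might look like the following. It assumes SparkContext exposes an addSparkListener registration hook and that JobLogger has a no-argument constructor; both should be checked against this codebase.

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.JobLogger

object JobLoggerExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "job-logger-example")
    sc.addSparkListener(new JobLogger())           // assumed registration hook
    sc.parallelize(1 to 1000).map(_ * 2).count()   // generates JOB/STAGE/TASK log records
    sc.stop()
  }
}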
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala new file mode 100644 index 0000000000..c381348a8d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +/** + * A result of a job in the DAGScheduler. + */ +private[spark] sealed trait JobResult + +private[spark] case object JobSucceeded extends JobResult +private[spark] case class JobFailed(exception: Exception, failedStage: Option[Stage]) extends JobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala new file mode 100644 index 0000000000..200d881799 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable.ArrayBuffer + +/** + * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their + * results to the given handler function. + */ +private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) + extends JobListener { + + private var finishedTasks = 0 + + private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? 
+ private var jobResult: JobResult = null // If the job is finished, this will be its result + + override def taskSucceeded(index: Int, result: Any) { + synchronized { + if (jobFinished) { + throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + } + resultHandler(index, result.asInstanceOf[T]) + finishedTasks += 1 + if (finishedTasks == totalTasks) { + jobFinished = true + jobResult = JobSucceeded + this.notifyAll() + } + } + } + + override def jobFailed(exception: Exception) { + synchronized { + if (jobFinished) { + throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") + } + jobFinished = true + jobResult = JobFailed(exception, None) + this.notifyAll() + } + } + + def awaitResult(): JobResult = synchronized { + while (!jobFinished) { + this.wait() + } + return jobResult + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala new file mode 100644 index 0000000000..1c61687f28 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.storage.BlockManagerId +import java.io.{ObjectOutput, ObjectInput, Externalizable} + +/** + * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the + * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * The map output sizes are compressed using MapOutputTracker.compressSize. + */ +private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) + extends Externalizable { + + def this() = this(null, null) // For deserialization only + + def writeExternal(out: ObjectOutput) { + location.writeExternal(out) + out.writeInt(compressedSizes.length) + out.write(compressedSizes) + } + + def readExternal(in: ObjectInput) { + location = BlockManagerId(in) + compressedSizes = new Array[Byte](in.readInt()) + in.readFully(compressedSizes) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala new file mode 100644 index 0000000000..2b007cbe82 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io._ +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.RDDCheckpointData +import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} + +private[spark] object ResultTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) + + def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId).orNull + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(func) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] + return (rdd, func) + } + + def clearCache() { + synchronized { + serializedInfoCache.clear() + } + } +} + + +private[spark] class ResultTask[T, U]( + stageId: Int, + var rdd: RDD[T], + var func: (TaskContext, Iterator[T]) => U, + var partition: Int, + @transient locs: Seq[TaskLocation], + val outputId: Int) + extends Task[U](stageId) with Externalizable { + + def this() = this(0, null, null, 0, null, 0) + + var split = if (rdd == null) { + null + } else { + rdd.partitions(partition) + } + + @transient private val preferredLocs: Seq[TaskLocation] = { + if (locs == null) Nil else locs.toSet.toSeq + } + + override def run(attemptId: Long): U = { + val context = new TaskContext(stageId, partition, attemptId) + metrics = Some(context.taskMetrics) + try { + func(context, rdd.iterator(split, context)) + } finally { + context.executeOnCompleteCallbacks() + } + } + + override def preferredLocations: Seq[TaskLocation] = preferredLocs + + override def toString = "ResultTask(" + stageId + ", " + partition + ")" + + override def writeExternal(out: ObjectOutput) { + RDDCheckpointData.synchronized { + split = rdd.partitions(partition) + out.writeInt(stageId) + val bytes = ResultTask.serializeInfo( + stageId, rdd, 
func.asInstanceOf[(TaskContext, Iterator[_]) => _]) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeInt(outputId) + out.writeLong(epoch) + out.writeObject(split) + } + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) + rdd = rdd_.asInstanceOf[RDD[T]] + func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] + partition = in.readInt() + val outputId = in.readInt() + epoch = in.readLong() + split = in.readObject().asInstanceOf[Partition] + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala new file mode 100644 index 0000000000..764775fede --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io._ +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +import scala.collection.mutable.HashMap + +import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage._ +import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner} +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.RDDCheckpointData + + +private[spark] object ShuffleMapTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. 
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) + + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId).orNull + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val ser = SparkEnv.get.closureSerializer.newInstance() + val objOut = ser.serializeStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(dep) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { + synchronized { + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance() + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + return (rdd, dep) + } + } + + // Since both the JarSet and FileSet have the same format this is used for both. + def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val objIn = new ObjectInputStream(in) + val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap + return (HashMap(set.toSeq: _*)) + } + + def clearCache() { + synchronized { + serializedInfoCache.clear() + } + } +} + +private[spark] class ShuffleMapTask( + stageId: Int, + var rdd: RDD[_], + var dep: ShuffleDependency[_,_], + var partition: Int, + @transient private var locs: Seq[TaskLocation]) + extends Task[MapStatus](stageId) + with Externalizable + with Logging { + + protected def this() = this(0, null, null, 0, null) + + @transient private val preferredLocs: Seq[TaskLocation] = { + if (locs == null) Nil else locs.toSet.toSeq + } + + var split = if (rdd == null) null else rdd.partitions(partition) + + override def writeExternal(out: ObjectOutput) { + RDDCheckpointData.synchronized { + split = rdd.partitions(partition) + out.writeInt(stageId) + val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeLong(epoch) + out.writeObject(split) + } + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) + rdd = rdd_ + dep = dep_ + partition = in.readInt() + epoch = in.readLong() + split = in.readObject().asInstanceOf[Partition] + } + + override def run(attemptId: Long): MapStatus = { + val numOutputSplits = dep.partitioner.numPartitions + + val taskContext = new TaskContext(stageId, partition, attemptId) + metrics = Some(taskContext.taskMetrics) + + val blockManager = SparkEnv.get.blockManager + var shuffle: ShuffleBlocks = null + var buckets: ShuffleWriterGroup = null + + try { + // Obtain all the block writers for shuffle blocks. + val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) + shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) + buckets = shuffle.acquireWriters(partition) + + // Write the map output to its associated buckets. 
+ for (elem <- rdd.iterator(split, taskContext)) { + val pair = elem.asInstanceOf[Product2[Any, Any]] + val bucketId = dep.partitioner.getPartition(pair._1) + buckets.writers(bucketId).write(pair) + } + + // Commit the writes. Get the size of each bucket block (total block size). + var totalBytes = 0L + val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.size() + totalBytes += size + MapOutputTracker.compressSize(size) + } + + // Update shuffle metrics. + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) + + return new MapStatus(blockManager.blockManagerId, compressedSizes) + } catch { case e: Exception => + // If there is an exception from running the task, revert the partial writes + // and throw the exception upstream to Spark. + if (buckets != null) { + buckets.writers.foreach(_.revertPartialWrites()) + } + throw e + } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && buckets != null) { + shuffle.releaseWriters(buckets) + } + // Execute the callbacks on task completion. + taskContext.executeOnCompleteCallbacks() + } + } + + override def preferredLocations: Seq[TaskLocation] = preferredLocs + + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala new file mode 100644 index 0000000000..c3cf4b8907 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util.Properties +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.{Logging, SparkContext, TaskEndReason} +import org.apache.spark.executor.TaskMetrics + +sealed trait SparkListenerEvents + +case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties) + extends SparkListenerEvents + +case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents + +case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents + +case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, + taskMetrics: TaskMetrics) extends SparkListenerEvents + +case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) + extends SparkListenerEvents + +case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) + extends SparkListenerEvents + +trait SparkListener { + /** + * Called when a stage is completed, with information on the completed stage + */ + def onStageCompleted(stageCompleted: StageCompleted) { } + + /** + * Called when a stage is submitted + */ + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + + /** + * Called when a task starts + */ + def onTaskStart(taskEnd: SparkListenerTaskStart) { } + + /** + * Called when a task ends + */ + def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } + + /** + * Called when a job starts + */ + def onJobStart(jobStart: SparkListenerJobStart) { } + + /** + * Called when a job ends + */ + def onJobEnd(jobEnd: SparkListenerJobEnd) { } + +} + +/** + * Simple SparkListener that logs a few summary statistics when each stage completes + */ +class StatsReportListener extends SparkListener with Logging { + override def onStageCompleted(stageCompleted: StageCompleted) { + import org.apache.spark.scheduler.StatsReportListener._ + implicit val sc = stageCompleted + this.logInfo("Finished stage: " + stageCompleted.stageInfo) + showMillisDistribution("task runtime:", (info, _) => Some(info.duration)) + + //shuffle write + showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten}) + + //fetch & io + showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.fetchWaitTime}) + showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead}) + showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize)) + + //runtime breakdown + + val runtimePcts = stageCompleted.stageInfo.taskInfos.map{ + case (info, metrics) => RuntimePercentage(info.duration, metrics) + } + showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%") + showDistribution("fetch wait time pct: ", Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%") + showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%") + } + +} + +object StatsReportListener extends Logging { + + //for profiling, the extremes are more interesting + val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) + val probabilities = percentiles.map{_ / 100.0} + val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" + + def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { + Distribution(stage.stageInfo.taskInfos.flatMap{ + case 
((info,metric)) => getMetric(info, metric)}) + } + + //is there some way to setup the types that I can get rid of this completely? + def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = { + extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble}) + } + + def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { + val stats = d.statCounter + logInfo(heading + stats) + val quantiles = d.getQuantiles(probabilities).map{formatNumber} + logInfo(percentilesHeader) + logInfo("\t" + quantiles.mkString("\t")) + } + + def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) { + dOpt.foreach { d => showDistribution(heading, d, formatNumber)} + } + + def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { + def f(d:Double) = format.format(d) + showDistribution(heading, dOpt, f _) + } + + def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double]) + (implicit stage: StageCompleted) { + showDistribution(heading, extractDoubleDistribution(stage, getMetric), format) + } + + def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long]) + (implicit stage: StageCompleted) { + showBytesDistribution(heading, extractLongDistribution(stage, getMetric)) + } + + def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { + dOpt.foreach{dist => showBytesDistribution(heading, dist)} + } + + def showBytesDistribution(heading: String, dist: Distribution) { + showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) + } + + def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { + showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String) + } + + def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long]) + (implicit stage: StageCompleted) { + showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) + } + + + + val seconds = 1000L + val minutes = seconds * 60 + val hours = minutes * 60 + + /** + * reformat a time interval in milliseconds to a prettier format for output + */ + def millisToString(ms: Long) = { + val (size, units) = + if (ms > hours) { + (ms.toDouble / hours, "hours") + } else if (ms > minutes) { + (ms.toDouble / minutes, "min") + } else if (ms > seconds) { + (ms.toDouble / seconds, "s") + } else { + (ms.toDouble, "ms") + } + "%.1f %s".format(size, units) + } +} + + + +case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) +object RuntimePercentage { + def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { + val denom = totalTime.toDouble + val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime} + val fetch = fetchTime.map{_ / denom} + val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom + val other = 1.0 - (exec + fetch.getOrElse(0d)) + RuntimePercentage(exec, fetch, other) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala new file mode 100644 index 0000000000..a65e1ecd6d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -0,0 +1,74 @@ +/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+
+import org.apache.spark.Logging
+
+/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */
+private[spark] class SparkListenerBus() extends Logging {
+ private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener]
+
+ /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
+ * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
+ private val EVENT_QUEUE_CAPACITY = 10000
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY)
+ private var queueFullErrorMessageLogged = false
+
+ new Thread("SparkListenerBus") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ event match {
+ case stageSubmitted: SparkListenerStageSubmitted =>
+ sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
+ case stageCompleted: StageCompleted =>
+ sparkListeners.foreach(_.onStageCompleted(stageCompleted))
+ case jobStart: SparkListenerJobStart =>
+ sparkListeners.foreach(_.onJobStart(jobStart))
+ case jobEnd: SparkListenerJobEnd =>
+ sparkListeners.foreach(_.onJobEnd(jobEnd))
+ case taskStart: SparkListenerTaskStart =>
+ sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskEnd: SparkListenerTaskEnd =>
+ sparkListeners.foreach(_.onTaskEnd(taskEnd))
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ def addListener(listener: SparkListener) {
+ sparkListeners += listener
+ }
+
+ def post(event: SparkListenerEvents) {
+ val eventAdded = eventQueue.offer(event)
+ if (!eventAdded && !queueFullErrorMessageLogged) {
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
+ "rate at which tasks are being started by the scheduler.")
+ queueFullErrorMessageLogged = true
+ }
+ }
+}
+
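A rough usage sketch of the bus follows; in practice it is created and fed by the scheduler rather than by user code. The listener class is invented for illustration, and accessing job.jobId assumes the code lives somewhere ActiveJob is visible (for example inside org.apache.spark.scheduler).

import org.apache.spark.scheduler.{SparkListener, SparkListenerBus, SparkListenerJobStart}

// Illustrative listener: prints a line whenever a job-start event is delivered.
class JobStartPrinter extends SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart) {
    println("job started: " + jobStart.job.jobId)
  }
}

// Wiring sketch (normally done inside the DAGScheduler):
//   val bus = new SparkListenerBus()
//   bus.addListener(new JobStartPrinter())
//   bus.post(jobStartEvent)   // non-blocking; delivered on the "SparkListenerBus" daemon thread,
//                             // and dropped once EVENT_QUEUE_CAPACITY events are already pending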
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala new file mode 100644 index 0000000000..5b40a3eb29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import collection.mutable.ArrayBuffer + +// information about a specific split instance : handles both split instances. +// So that we do not need to worry about the differences. +class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, + val length: Long, val underlyingSplit: Any) { + override def toString(): String = { + "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + + ", hostLocation : " + hostLocation + ", path : " + path + + ", length : " + length + ", underlyingSplit " + underlyingSplit + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + hostLocation.hashCode + hashCode = hashCode * 31 + path.hashCode + // ignore overflow ? It is hashcode anyway ! + hashCode = hashCode * 31 + (length & 0x7fffffff).toInt + hashCode + } + + // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // So unless there is identity equality between underlyingSplits, it will always fail even if it + // is pointing to same block. 
+ override def equals(other: Any): Boolean = other match { + case that: SplitInfo => { + this.hostLocation == that.hostLocation && + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path && + this.length == that.length && + // other split specific checks (like start for FileSplit) + this.underlyingSplit == that.underlyingSplit + } + case _ => false + } +} + +object SplitInfo { + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapredSplit.getLength + for (host <- mapredSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) + } + retval + } + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapreduceSplit.getLength + for (host <- mapreduceSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) + } + retval + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala new file mode 100644 index 0000000000..aa293dc6b3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.BlockManagerId + +/** + * A stage is a set of independent tasks all computing the same function that need to run as part + * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run + * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the + * DAGScheduler runs these stages in topological order. + * + * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for + * another stage, or a result stage, in which case its tasks directly compute the action that + * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes + * that each output partition is on. + * + * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered + * faster on failure. 
+ */ +private[spark] class Stage( + val id: Int, + val rdd: RDD[_], + val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage + val parents: List[Stage], + val jobId: Int, + callSite: Option[String]) + extends Logging { + + val isShuffleMap = shuffleDep != None + val numPartitions = rdd.partitions.size + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + var numAvailableOutputs = 0 + + /** When first task was submitted to scheduler. */ + var submissionTime: Option[Long] = None + var completionTime: Option[Long] = None + + private var nextAttemptId = 0 + + def isAvailable: Boolean = { + if (!isShuffleMap) { + true + } else { + numAvailableOutputs == numPartitions + } + } + + def addOutputLoc(partition: Int, status: MapStatus) { + val prevList = outputLocs(partition) + outputLocs(partition) = status :: prevList + if (prevList == Nil) + numAvailableOutputs += 1 + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + def removeOutputsOnExecutor(execId: String) { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) + } + } + + def newAttemptId(): Int = { + val id = nextAttemptId + nextAttemptId += 1 + return id + } + + val name = callSite.getOrElse(rdd.origin) + + override def toString = "Stage " + id + + override def hashCode(): Int = id +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala new file mode 100644 index 0000000000..72cb1c9ce8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.scheduler.cluster.TaskInfo +import scala.collection._ +import org.apache.spark.executor.TaskMetrics + +case class StageInfo( + val stage: Stage, + val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() +) { + override def toString = stage.rdd.toString +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala new file mode 100644 index 0000000000..598d91752a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.serializer.SerializerInstance +import java.io.{DataInputStream, DataOutputStream} +import java.nio.ByteBuffer +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import org.apache.spark.util.ByteBufferInputStream +import scala.collection.mutable.HashMap +import org.apache.spark.executor.TaskMetrics + +/** + * A task to execute on a worker node. + */ +private[spark] abstract class Task[T](val stageId: Int) extends Serializable { + def run(attemptId: Long): T + def preferredLocations: Seq[TaskLocation] = Nil + + var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. + + var metrics: Option[TaskMetrics] = None + +} + +/** + * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We + * need to send the list of JARs and files added to the SparkContext with each task to ensure that + * worker nodes find out about it, but we can't make it part of the Task because the user's code in + * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by + * first writing out its dependencies. 
+ */ +private[spark] object Task { + /** + * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) + */ + def serializeWithDependencies( + task: Task[_], + currentFiles: HashMap[String, Long], + currentJars: HashMap[String, Long], + serializer: SerializerInstance) + : ByteBuffer = { + + val out = new FastByteArrayOutputStream(4096) + val dataOut = new DataOutputStream(out) + + // Write currentFiles + dataOut.writeInt(currentFiles.size) + for ((name, timestamp) <- currentFiles) { + dataOut.writeUTF(name) + dataOut.writeLong(timestamp) + } + + // Write currentJars + dataOut.writeInt(currentJars.size) + for ((name, timestamp) <- currentJars) { + dataOut.writeUTF(name) + dataOut.writeLong(timestamp) + } + + // Write the task itself and finish + dataOut.flush() + val taskBytes = serializer.serialize(task).array() + out.write(taskBytes) + out.trim() + ByteBuffer.wrap(out.array) + } + + /** + * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, + * and return the task itself as a serialized ByteBuffer. The caller can then update its + * ClassLoaders and deserialize the task. + * + * @return (taskFiles, taskJars, taskBytes) + */ + def deserializeWithDependencies(serializedTask: ByteBuffer) + : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + + val in = new ByteBufferInputStream(serializedTask) + val dataIn = new DataInputStream(in) + + // Read task's files + val taskFiles = new HashMap[String, Long]() + val numFiles = dataIn.readInt() + for (i <- 0 until numFiles) { + taskFiles(dataIn.readUTF()) = dataIn.readLong() + } + + // Read task's JARs + val taskJars = new HashMap[String, Long]() + val numJars = dataIn.readInt() + for (i <- 0 until numJars) { + taskJars(dataIn.readUTF()) = dataIn.readLong() + } + + // Create a sub-buffer for the rest of the data, which is the serialized Task object + val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task + (taskFiles, taskJars, subBuffer) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala new file mode 100644 index 0000000000..67c9a6760b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +/** + * A location where a task should run. This can either be a host or a (host, executorID) pair. + * In the latter case, we will prefer to launch the task on that executorID, but our next level + * of preference will be executors on the same host if this is not possible. 
+ */ +private[spark] +class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable { + override def toString: String = "TaskLocation(" + host + ", " + executorId + ")" +} + +private[spark] object TaskLocation { + def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId)) + + def apply(host: String) = new TaskLocation(host, None) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala new file mode 100644 index 0000000000..5c7e5bb977 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io._ + +import scala.collection.mutable.Map +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.{SparkEnv} +import java.nio.ByteBuffer +import org.apache.spark.util.Utils + +// Task result. Also contains updates to accumulator variables. +// TODO: Use of distributed cache to return result is a hack to get around +// what seems to be a bug with messages over 60KB in libprocess; fix it +private[spark] +class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) + extends Externalizable +{ + def this() = this(null.asInstanceOf[T], null, null) + + override def writeExternal(out: ObjectOutput) { + + val objectSer = SparkEnv.get.serializer.newInstance() + val bb = objectSer.serialize(value) + + out.writeInt(bb.remaining()) + Utils.writeByteBuffer(bb, out) + + out.writeInt(accumUpdates.size) + for ((key, value) <- accumUpdates) { + out.writeLong(key) + out.writeObject(value) + } + out.writeObject(metrics) + } + + override def readExternal(in: ObjectInput) { + + val objectSer = SparkEnv.get.serializer.newInstance() + + val blen = in.readInt() + val byteVal = new Array[Byte](blen) + in.readFully(byteVal) + value = objectSer.deserialize(ByteBuffer.wrap(byteVal)) + + val numUpdates = in.readInt + if (numUpdates == 0) { + accumUpdates = null + } else { + accumUpdates = Map() + for (i <- 0 until numUpdates) { + accumUpdates(in.readLong()) = in.readObject() + } + } + metrics = in.readObject().asInstanceOf[TaskMetrics] + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala new file mode 100644 index 0000000000..63be8ba3f5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.scheduler.cluster.Pool +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode +/** + * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. + * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, + * and are responsible for sending the tasks to the cluster, running them, retrying if there + * are failures, and mitigating stragglers. They return events to the DAGScheduler through + * the TaskSchedulerListener interface. + */ +private[spark] trait TaskScheduler { + + def rootPool: Pool + + def schedulingMode: SchedulingMode + + def start(): Unit + + // Invoked after system has successfully initialized (typically in spark context). + // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + def postStartHook() { } + + // Disconnect from the cluster. + def stop(): Unit + + // Submit a sequence of tasks to run. + def submitTasks(taskSet: TaskSet): Unit + + // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. + def setListener(listener: TaskSchedulerListener): Unit + + // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. + def defaultParallelism(): Int +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala new file mode 100644 index 0000000000..83be051c1a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.scheduler.cluster.TaskInfo +import scala.collection.mutable.Map + +import org.apache.spark.TaskEndReason +import org.apache.spark.executor.TaskMetrics + +/** + * Interface for getting events back from the TaskScheduler. 
+ */ +private[spark] trait TaskSchedulerListener { + // A task has started. + def taskStarted(task: Task[_], taskInfo: TaskInfo) + + // A task has finished or failed. + def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], + taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit + + // A node was added to the cluster. + def executorGained(execId: String, host: String): Unit + + // A node was lost from the cluster. + def executorLost(execId: String): Unit + + // The TaskScheduler wants to abort an entire task set. + def taskSetFailed(taskSet: TaskSet, reason: String): Unit +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala new file mode 100644 index 0000000000..c3ad325156 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties + +/** + * A set of tasks submitted together to the low-level TaskScheduler, usually representing + * missing partitions of a particular stage. + */ +private[spark] class TaskSet( + val tasks: Array[Task[_]], + val stageId: Int, + val attempt: Int, + val priority: Int, + val properties: Properties) { + val id: String = stageId + "." + attempt + + override def toString: String = "TaskSet " + id +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala new file mode 100644 index 0000000000..3196ab5022 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster + +import java.lang.{Boolean => JBoolean} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import java.util.{TimerTask, Timer} + +/** + * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call + * initialize() and start(), then submit task sets through the runTasks method. + * + * This class can work with multiple types of clusters by acting through a SchedulerBackend. + * It handles common logic, like determining a scheduling order across jobs, waking up to launch + * speculative tasks, etc. + * + * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple + * threads, so it needs locks in public API methods to maintain its state. In addition, some + * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * acquire a lock on us, so we need to make sure that we don't try to lock the backend while + * we are holding a lock on ourselves. + */ +private[spark] class ClusterScheduler(val sc: SparkContext) + extends TaskScheduler + with Logging +{ + // How often to check for speculative tasks + val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + + // Threshold above which we warn user initial TaskSet may be starved + val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + + val activeTaskSets = new HashMap[String, TaskSetManager] + + val taskIdToTaskSetId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) + + // Incrementing Mesos task IDs + val nextTaskId = new AtomicLong(0) + + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] + + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + private val executorsByHost = new HashMap[String, HashSet[String]] + + private val executorIdToHost = new HashMap[String, String] + + // JAR server, if any JARs were added by the user to the SparkContext + var jarServer: HttpServer = null + + // URIs of JARs to pass to executor + var jarUris: String = "" + + // Listener object to pass upcalls into + var listener: TaskSchedulerListener = null + + var backend: SchedulerBackend = null + + val mapOutputTracker = SparkEnv.get.mapOutputTracker + + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + // default scheduler is FIFO + val schedulingMode: SchedulingMode = SchedulingMode.withName( + System.getProperty("spark.cluster.schedulingmode", "FIFO")) + + override def setListener(listener: TaskSchedulerListener) { + this.listener = listener + } + + def initialize(context: SchedulerBackend) { + backend = context + // temporarily set rootPool name to empty + rootPool = new Pool("", schedulingMode, 0, 0) + schedulableBuilder = { + schedulingMode match { + case SchedulingMode.FIFO => + new 
FIFOSchedulableBuilder(rootPool) + case SchedulingMode.FAIR => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + } + + def newTaskId(): Long = nextTaskId.getAndIncrement() + + override def start() { + backend.start() + + if (System.getProperty("spark.speculation", "false").toBoolean) { + new Thread("ClusterScheduler speculation check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative execution thread") + while (true) { + try { + Thread.sleep(SPECULATION_INTERVAL) + } catch { + case e: InterruptedException => {} + } + checkSpeculatableTasks() + } + } + }.start() + } + } + + override def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") + this.synchronized { + val manager = new ClusterTaskSetManager(this, taskSet) + activeTaskSets(taskSet.id) = manager + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + + if (!hasReceivedTask) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered " + + "and have sufficient memory") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true + } + backend.reviveOffers() + } + + def taskSetFinished(manager: TaskSetManager) { + this.synchronized { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) + } + } + + /** + * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so + * that tasks are balanced across the cluster. + */ + def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { + SparkEnv.set(sc.env) + + // Mark each slave as alive and remember its hostname + for (o <- offers) { + executorIdToHost(o.executorId) = o.host + if (!executorsByHost.contains(o.host)) { + executorsByHost(o.host) = new HashSet[String]() + executorGained(o.executorId, o.host) + } + } + + // Build a list of tasks to assign to each worker + val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + val sortedTaskSets = rootPool.getSortedTaskSetQueue() + for (taskSet <- sortedTaskSets) { + logDebug("parentName: %s, name: %s, runningTasks: %s".format( + taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + } + + // Take each TaskSet in our scheduling order, and then offer it each node in increasing order + // of locality levels so that it gets a chance to launch local tasks on all of them. 
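+ // Rough sketch of the iteration order: for each task set (in the sorted order above) we sweep
+ // all offers repeatedly at one locality level until nothing more launches, then relax to the
+ // next level (process-local first, ANY last) before moving on to the next task set.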
+ var launchedTask = false + for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { + do { + launchedTask = false + for (i <- 0 until offers.size) { + val execId = offers(i).executorId + val host = offers(i).host + for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskSetTaskIds(taskSet.taskSet.id) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + } + } + } while (launchedTask) + } + + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + var taskSetToUpdate: Option[TaskSetManager] = None + var failedExecutor: Option[String] = None + var taskFailed = false + synchronized { + try { + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) + } + } + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (activeTaskSets.contains(taskSetId)) { + taskSetToUpdate = Some(activeTaskSets(taskSetId)) + } + if (TaskState.isFinished(state)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + taskIdToExecutorId.remove(tid) + } + if (state == TaskState.FAILED) { + taskFailed = true + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock + if (taskSetToUpdate != None) { + taskSetToUpdate.get.statusUpdate(tid, state, serializedData) + } + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) + backend.reviveOffers() + } + if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost + backend.reviveOffers() + } + } + + def error(message: String) { + synchronized { + if (activeTaskSets.size > 0) { + // Have each task set throw a SparkException with the error + for ((taskSetId, manager) <- activeTaskSets) { + try { + manager.error(message) + } catch { + case e: Exception => logError("Exception in error callback", e) + } + } + } else { + // No task sets are active but we still got an error. Just exit since this + // must mean the error is during registration. + // It might be good to do something smarter here in the future. + logError("Exiting due to error from cluster scheduler: " + message) + System.exit(1) + } + } + } + + override def stop() { + if (backend != null) { + backend.stop() + } + if (jarServer != null) { + jarServer.stop() + } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) + } + + override def defaultParallelism() = backend.defaultParallelism() + + + // Check for speculatable tasks in all our active jobs. + def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + shouldRevive = rootPool.checkSpeculatableTasks() + } + if (shouldRevive) { + backend.reviveOffers() + } + } + + // Check for pending tasks in all our active jobs. 
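+ // Delegates to the root Pool, which recursively asks every Schedulable in its tree (ultimately
+ // the TaskSetManagers) whether any of their tasks are still pending.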
+ def hasPendingTasks: Boolean = { + synchronized { + rootPool.hasPendingTasks() + } + } + + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None + + synchronized { + if (activeExecutorIds.contains(executorId)) { + val hostPort = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor " + executorId + " (already removed): " + reason) + } + } + // Call listener.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) + backend.reviveOffers() + } + } + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + rootPool.executorLost(executorId, host) + } + + def executorGained(execId: String, host: String) { + listener.executorGained(execId, host) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { + executorsByHost.get(host).map(_.toSet) + } + + def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { + executorsByHost.contains(host) + } + + def isExecutorAlive(execId: String): Boolean = synchronized { + activeExecutorIds.contains(execId) + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None +} + + +object ClusterScheduler { + /** + * Used to balance containers across hosts. + * + * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of + * resource offers representing the order in which the offers should be used. The resource + * offers are ordered such that we'll allocate one container on each host before allocating a + * second container on any host, and so on, in order to reduce the damage if a host fails. 
+ *
+ * For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
+ * [o1, o5, o4, o2, o6, o3]
+ */
+ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
+ val _keyList = new ArrayBuffer[K](map.size)
+ _keyList ++= map.keys
+
+ // order keyList based on population of value in map
+ val keyList = _keyList.sortWith(
+ (left, right) => map(left).size > map(right).size
+ )
+
+ val retval = new ArrayBuffer[T](keyList.size * 2)
+ var index = 0
+ var found = true
+
+ while (found) {
+ found = false
+ for (key <- keyList) {
+ val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+ assert(containerList != null)
+ // Get the index'th entry for this host - if present
+ if (index < containerList.size){
+ retval += containerList.apply(index)
+ found = true
+ }
+ }
+ index += 1
+ }
+
+ retval.toList
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
new file mode 100644
index 0000000000..1b31c8c57e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -0,0 +1,712 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+import java.util.{Arrays, NoSuchElementException}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState}
+import org.apache.spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler._
+import scala.Some
+import org.apache.spark.FetchFailed
+import org.apache.spark.ExceptionFailure
+import org.apache.spark.TaskResultTooBigFailure
+import org.apache.spark.util.{SystemClock, Clock}
+
+
+/**
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * the status of each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
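+ *
+ * Behaviour is tuned via the system properties read below, e.g. spark.task.maxFailures,
+ * spark.speculation.quantile, spark.speculation.multiplier and the spark.locality.wait.*
+ * settings that drive delay scheduling.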
+ */ +private[spark] class ClusterTaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet, + clock: Clock = SystemClock) + extends TaskSetManager + with Logging +{ + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksFinished = 0 + + var weight = 1 + var minShare = 0 + var runningTasks = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent: Schedulable = null + + // Set of pending tasks for each executor. These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each host. Similar to pendingTasksForExecutor, + // but at host level. + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each rack -- similar to the above. + private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] + + // Set containing pending tasks with no locality preferences. + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // Set containing all pending tasks (also used as a stack, as above). + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the TaskSet fail? + var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + + // Map of recent exceptions (identified by string representation and top stack frame) to + // duplicate count (how many times the same exception has appeared) and time the full exception + // was printed. This should ideally be an LRU map that can drop old exceptions automatically. 
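+ // Keyed by the exception's description; updated in taskLost() when an ExceptionFailure is
+ // reported, using EXCEPTION_PRINT_INTERVAL above to decide when to print the full trace again.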
+ val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker epoch and set it on all tasks + val epoch = sched.mapOutputTracker.getEpoch + logDebug("Epoch for " + taskSet + ": " + epoch) + for (t <- tasks) { + t.epoch = epoch + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling + val myLocalityLevels = computeValidLocalityLevels() + val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + + // Delay scheduling variables: we keep track of our current locality level and the time we + // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. + // We then move down if we manage to launch a "more local" task. + var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + + /** + * Add a task to all the pending-task lists that it should be on. If readding is set, we are + * re-adding the task so only include it in each list if it's not already there. + */ + private def addPendingTask(index: Int, readding: Boolean = false) { + // Utility method that adds `index` to a list only if readding=false or it's not already there + def addTo(list: ArrayBuffer[Int]) { + if (!readding || !list.contains(index)) { + list += index + } + } + + var hadAliveLocations = false + for (loc <- tasks(index).preferredLocations) { + for (execId <- loc.executorId) { + if (sched.isExecutorAlive(execId)) { + addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) + hadAliveLocations = true + } + } + if (sched.hasExecutorsAliveOnHost(loc.host)) { + addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + for (rack <- sched.getRackForHost(loc.host)) { + addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) + } + hadAliveLocations = true + } + } + + if (!hadAliveLocations) { + // Even though the task might've had preferred locations, all of those hosts or executors + // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. + addTo(pendingTasksWithNoPrefs) + } + + if (!readding) { + allPendingTasks += index // No point scanning this whole list to find the old task there + } + } + + /** + * Return the pending tasks list for a given executor ID, or an empty list if + * there is no map entry for that host + */ + private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { + pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) + } + + /** + * Return the pending tasks list for a given host, or an empty list if + * there is no map entry for that host + */ + private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + /** + * Return the pending rack-local task list for a given rack, or an empty list if + * there is no map entry for that rack + */ + private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { + pendingTasksForRack.getOrElse(rack, ArrayBuffer()) + } + + /** + * Dequeue a pending task from the given list and return its index. + * Return None if the list is empty. 
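+ * The list is treated as a stack: the most recently added (or re-added) index is tried first.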
+ * This method also cleans up any tasks in the list that have already + * been launched, since we want that to happen lazily. + */ + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (copiesRunning(index) == 0 && !finished(index)) { + return Some(index) + } + } + return None + } + + /** Check whether a task is currently running an attempt on a given host */ + private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { + !taskAttempts(taskIndex).exists(_.host == host) + } + + /** + * Return a speculative task for a given executor if any are available. The task should not have + * an attempt running on this host, in case the host is slow. In addition, the task should meet + * the given locality constraint. + */ + private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set + + if (!speculatableTasks.isEmpty) { + // Check for process-local or preference-less tasks; note that tasks can be process-local + // on multiple nodes when we replicate cached blocks, as in Spark Streaming + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val prefs = tasks(index).preferredLocations + val executors = prefs.flatMap(_.executorId) + if (prefs.size == 0 || executors.contains(execId)) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + + // Check for node-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val locations = tasks(index).preferredLocations.map(_.host) + if (locations.contains(host)) { + speculatableTasks -= index + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + } + + // Check for rack-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for (rack <- sched.getRackForHost(host)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) + if (racks.contains(rack)) { + speculatableTasks -= index + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + } + } + + // Check for non-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + speculatableTasks -= index + return Some((index, TaskLocality.ANY)) + } + } + } + + return None + } + + /** + * Dequeue a pending task for a given node and return its index and locality level. + * Only search for tasks matching the given locality constraint. + */ + private def findTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- findTaskFromList(getPendingTasksForHost(host))) { + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for { + rack <- sched.getRackForHost(host) + index <- findTaskFromList(getPendingTasksForRack(rack)) + } { + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + + // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
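+ // They are returned as PROCESS_LOCAL, which (see getLocalityIndex) keeps the delay-scheduling
+ // level at its best value rather than pushing it towards ANY; that appears to be the intent.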
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- findTaskFromList(allPendingTasks)) { + return Some((index, TaskLocality.ANY)) + } + } + + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(execId, host, locality) + } + + /** + * Respond to an offer of a single slave from the scheduler by finding a task + */ + override def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] = + { + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { + val curTime = clock.getTime() + + var allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + } + + findTask(execId, host, allowedLocality) match { + case Some((index, taskLocality)) => { + // Found a task; do some bookkeeping and return a task description + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, taskLocality)) + // Do various bookkeeping + copiesRunning(index) += 1 + val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + // Update our locality level for delay scheduling + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + // Serialize and return the task + val startTime = clock.getTime() + // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here + // we assume the task can be serialized without exceptions. + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = clock.getTime() - startTime + increaseRunningTasks(1) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + if (taskAttempts(index).size == 1) + taskStarted(task,info) + return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) + } + case _ => + } + } + return None + } + + /** + * Get the level we can launch tasks according to delay scheduling, based on current wait time. + */ + private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { + while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && + currentLocalityIndex < myLocalityLevels.length - 1) + { + // Jump to the next locality level, and remove our waiting time for the current one since + // we don't want to count it again on the next one + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + } + myLocalityLevels(currentLocalityIndex) + } + + /** + * Find the index in myLocalityLevels for a given locality. This is also designed to work with + * localities that are not in myLocalityLevels (in case we somehow get those) by returning the + * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. 
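+ * For example, if myLocalityLevels is [PROCESS_LOCAL, ANY], this returns 0 for PROCESS_LOCAL
+ * and 1 for NODE_LOCAL, RACK_LOCAL and ANY.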
+ */ + def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { + var index = 0 + while (locality > myLocalityLevels(index)) { + index += 1 + } + index + } + + /** Called by cluster scheduler when one of our tasks changes state */ + override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + SparkEnv.set(env) + state match { + case TaskState.FINISHED => + taskFinished(tid, state, serializedData) + case TaskState.LOST => + taskLost(tid, state, serializedData) + case TaskState.FAILED => + taskLost(tid, state, serializedData) + case TaskState.KILLED => + taskLost(tid, state, serializedData) + case _ => + } + } + + def taskStarted(task: Task[_], info: TaskInfo) { + sched.listener.taskStarted(task, info) + } + + def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + decreaseRunningTasks(1) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( + tid, info.duration, info.host, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + try { + val result = ser.deserialize[TaskResult[_]](serializedData) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded( + tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + } catch { + case cnf: ClassNotFoundException => + val loader = Thread.currentThread().getContextClassLoader + throw new SparkException("ClassNotFound with classloader: " + loader, cnf) + case ex => throw ex + } + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markFailed() + decreaseRunningTasks(1) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. 
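+ // The deserialized reason is matched below: a FetchFailed finishes the whole task set
+ // immediately, a TaskResultTooBigFailure aborts it, and an ExceptionFailure is logged (with
+ // duplicate suppression) before the task is re-queued and its failure count checked.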
+ if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + return + + case taskResultTooBig: TaskResultTooBigFailure => + logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format( + tid)) + abort("Task %s result exceeded Akka frame size".format(tid)) + return + + case ef: ExceptionFailure => + sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + val key = ef.description + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + override def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.listener.taskSetFailed(taskSet, message) + decreaseRunningTasks(runningTasks) + sched.taskSetFinished(this) + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable: Schedulable) {} + + override def removeSchedulable(schedulable: Schedulable) {} + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */ + override def executorLost(execId: String, host: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // 
Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a + // task that used to have locations on only this host might now go to the no-prefs list. Note + // that it's okay if we add a task to the same queue twice (if it had multiple preferred + // locations), because findTaskFromList will skip already-running tasks. + for (index <- getPendingTasksForExecutor(execId)) { + addPendingTask(index, readding=true) + } + for (index <- getPendingTasksForHost(host)) { + addPendingTask(index, readding=true) + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (finished(index)) { + finished(index) = false + copiesRunning(index) -= 1 + tasksFinished -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + taskLost(tid, TaskState.KILLED, null) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the ClusterScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. + if (numTasks == 1 || tasksFinished == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksFinished >= minFinishedForSpeculation) { + val time = clock.getTime() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. 
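+ // Worked example with the default settings: once 75% of the tasks have finished
+ // (spark.speculation.quantile), any task that has been running for more than
+ // 1.5 x the median successful duration (spark.speculation.multiplier), and at least 100 ms,
+ // becomes a candidate for speculation.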
+ logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.host, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + override def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } + + private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { + val defaultWait = System.getProperty("spark.locality.wait", "3000") + level match { + case TaskLocality.PROCESS_LOCAL => + System.getProperty("spark.locality.wait.process", defaultWait).toLong + case TaskLocality.NODE_LOCAL => + System.getProperty("spark.locality.wait.node", defaultWait).toLong + case TaskLocality.RACK_LOCAL => + System.getProperty("spark.locality.wait.rack", defaultWait).toLong + case TaskLocality.ANY => + 0L + } + } + + /** + * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been + * added to queues using addPendingTask. + */ + private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + val levels = new ArrayBuffer[TaskLocality.TaskLocality] + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + levels += PROCESS_LOCAL + } + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + levels += NODE_LOCAL + } + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + levels += RACK_LOCAL + } + levels += ANY + logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) + levels.toArray + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala new file mode 100644 index 0000000000..5077b2b48b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.executor.ExecutorExitCode + +/** + * Represents an explanation for a executor or whole slave failing or exiting. 
+ */ +private[spark] +class ExecutorLossReason(val message: String) { + override def toString: String = message +} + +private[spark] +case class ExecutorExited(val exitCode: Int) + extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +} + +private[spark] +case class SlaveLost(_message: String = "Slave lost") + extends ExecutorLossReason(_message) { +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala new file mode 100644 index 0000000000..35b32600da --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import org.apache.spark.Logging +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + +/** + * An Schedulable entity that represent collection of Pools or TaskSetManagers + */ + +private[spark] class Pool( + val poolName: String, + val schedulingMode: SchedulingMode, + initMinShare: Int, + initWeight: Int) + extends Schedulable + with Logging { + + var schedulableQueue = new ArrayBuffer[Schedulable] + var schedulableNameToSchedulable = new HashMap[String, Schedulable] + + var weight = initWeight + var minShare = initMinShare + var runningTasks = 0 + + var priority = 0 + var stageId = 0 + var name = poolName + var parent:Schedulable = null + + var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + schedulingMode match { + case SchedulingMode.FAIR => + new FairSchedulingAlgorithm() + case SchedulingMode.FIFO => + new FIFOSchedulingAlgorithm() + } + } + + override def addSchedulable(schedulable: Schedulable) { + schedulableQueue += schedulable + schedulableNameToSchedulable(schedulable.name) = schedulable + schedulable.parent= this + } + + override def removeSchedulable(schedulable: Schedulable) { + schedulableQueue -= schedulable + schedulableNameToSchedulable -= schedulable.name + } + + override def getSchedulableByName(schedulableName: String): Schedulable = { + if (schedulableNameToSchedulable.contains(schedulableName)) { + return schedulableNameToSchedulable(schedulableName) + } + for (schedulable <- schedulableQueue) { + var sched = schedulable.getSchedulableByName(schedulableName) + if (sched != null) { + return sched + } + } + return null + } + + override def executorLost(executorId: String, host: String) { + schedulableQueue.foreach(_.executorLost(executorId, host)) + } + + override def checkSpeculatableTasks(): Boolean = { + var shouldRevive = false + for (schedulable <- schedulableQueue) { + shouldRevive |= schedulable.checkSpeculatableTasks() + } + return 
shouldRevive + } + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) + for (schedulable <- sortedSchedulableQueue) { + sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() + } + return sortedTaskSetQueue + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def hasPendingTasks(): Boolean = { + schedulableQueue.exists(_.hasPendingTasks()) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala new file mode 100644 index 0000000000..f4726450ec --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + +import scala.collection.mutable.ArrayBuffer +/** + * An interface for schedulable entities. + * there are two type of Schedulable entities(Pools and TaskSetManagers) + */ +private[spark] trait Schedulable { + var parent: Schedulable + // child queues + def schedulableQueue: ArrayBuffer[Schedulable] + def schedulingMode: SchedulingMode + def weight: Int + def minShare: Int + def runningTasks: Int + def priority: Int + def stageId: Int + def name: String + + def increaseRunningTasks(taskNum: Int): Unit + def decreaseRunningTasks(taskNum: Int): Unit + def addSchedulable(schedulable: Schedulable): Unit + def removeSchedulable(schedulable: Schedulable): Unit + def getSchedulableByName(name: String): Schedulable + def executorLost(executorId: String, host: String): Unit + def checkSpeculatableTasks(): Boolean + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] + def hasPendingTasks(): Boolean +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala new file mode 100644 index 0000000000..d04eeb6b98 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.io.{File, FileInputStream, FileOutputStream, FileNotFoundException} +import java.util.Properties + +import scala.xml.XML + +import org.apache.spark.Logging +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + + +/** + * An interface to build Schedulable tree + * buildPools: build the tree nodes(pools) + * addTaskSetManager: build the leaf nodes(TaskSetManagers) + */ +private[spark] trait SchedulableBuilder { + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) +} + +private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) + extends SchedulableBuilder with Logging { + + override def buildPools() { + // nothing + } + + override def addTaskSetManager(manager: Schedulable, properties: Properties) { + rootPool.addSchedulable(manager) + } +} + +private[spark] class FairSchedulableBuilder(val rootPool: Pool) + extends SchedulableBuilder with Logging { + + val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file") + val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" + val DEFAULT_POOL_NAME = "default" + val MINIMUM_SHARES_PROPERTY = "minShare" + val SCHEDULING_MODE_PROPERTY = "schedulingMode" + val WEIGHT_PROPERTY = "weight" + val POOL_NAME_PROPERTY = "@name" + val POOLS_PROPERTY = "pool" + val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO + val DEFAULT_MINIMUM_SHARE = 2 + val DEFAULT_WEIGHT = 1 + + override def buildPools() { + if (schedulerAllocFile != null) { + val file = new File(schedulerAllocFile) + if (file.exists()) { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ POOLS_PROPERTY)) { + + val poolName = (poolNode \ POOL_NAME_PROPERTY).text + var schedulingMode = DEFAULT_SCHEDULING_MODE + var minShare = DEFAULT_MINIMUM_SHARE + var weight = DEFAULT_WEIGHT + + val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text + if (xmlSchedulingMode != "") { + try { + schedulingMode = SchedulingMode.withName(xmlSchedulingMode) + } catch { + case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode") + } + } + + val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text + if (xmlMinShare != "") { + minShare = xmlMinShare.toInt + } + + val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text + if (xmlWeight != "") { + weight = xmlWeight.toInt + } + + val pool = new Pool(poolName, schedulingMode, minShare, weight) + rootPool.addSchedulable(pool) + logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( + poolName, schedulingMode, minShare, weight)) + } + } else { + throw new java.io.FileNotFoundException( + "Fair scheduler allocation file not found: " + schedulerAllocFile) + } + } + + // finally create "default" pool + if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { + val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, + DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + 
rootPool.addSchedulable(pool) + logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( + DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } + } + + override def addTaskSetManager(manager: Schedulable, properties: Properties) { + var poolName = DEFAULT_POOL_NAME + var parentPool = rootPool.getSchedulableByName(poolName) + if (properties != null) { + poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + // we will create a new pool that user has configured in app + // instead of being defined in xml file + parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, + DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( + poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } + } + parentPool.addSchedulable(manager) + logInfo("Added task set " + manager.name + " tasks to pool "+poolName) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala new file mode 100644 index 0000000000..d57eb3276f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.{SparkContext} + +/** + * A backend interface for cluster scheduling systems that allows plugging in different ones under + * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as + * machines become available and can launch tasks on them. + */ +private[spark] trait SchedulerBackend { + def start(): Unit + def stop(): Unit + def reviveOffers(): Unit + def defaultParallelism(): Int + + // Memory used by each executor (in megabytes) + protected val executorMemory: Int = SparkContext.executorMemoryRequested + + // TODO: Probably want to add a killTask too +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala new file mode 100644 index 0000000000..cbeed4731a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +/** + * An interface for sort algorithm + * FIFO: FIFO algorithm between TaskSetManagers + * FS: FS algorithm between Pools, and FIFO or FS within Pools + */ +private[spark] trait SchedulingAlgorithm { + def comparator(s1: Schedulable, s2: Schedulable): Boolean +} + +private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { + val priority1 = s1.priority + val priority2 = s2.priority + var res = math.signum(priority1 - priority2) + if (res == 0) { + val stageId1 = s1.stageId + val stageId2 = s2.stageId + res = math.signum(stageId1 - stageId2) + } + if (res < 0) { + return true + } else { + return false + } + } +} + +private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { + val minShare1 = s1.minShare + val minShare2 = s2.minShare + val runningTasks1 = s1.runningTasks + val runningTasks2 = s2.runningTasks + val s1Needy = runningTasks1 < minShare1 + val s2Needy = runningTasks2 < minShare2 + val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble + val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble + val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble + val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble + var res:Boolean = true + var compare:Int = 0 + + if (s1Needy && !s2Needy) { + return true + } else if (!s1Needy && s2Needy) { + return false + } else if (s1Needy && s2Needy) { + compare = minShareRatio1.compareTo(minShareRatio2) + } else { + compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) + } + + if (compare < 0) { + return true + } else if (compare > 0) { + return false + } else { + return s1.name < s2.name + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala new file mode 100644 index 0000000000..34811389a0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +/** + * "FAIR" and "FIFO" determines which policy is used + * to order tasks amongst a Schedulable's sub-queues + * "NONE" is used when the a Schedulable has no sub-queues. + */ +object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") { + + type SchedulingMode = Value + val FAIR,FIFO,NONE = Value +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala new file mode 100644 index 0000000000..9a2cf20de7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.deploy.client.{Client, ClientListener} +import org.apache.spark.deploy.{Command, ApplicationDescription} +import scala.collection.mutable.HashMap +import org.apache.spark.util.Utils + +private[spark] class SparkDeploySchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext, + master: String, + appName: String) + extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + with ClientListener + with Logging { + + var client: Client = null + var stopping = false + var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ + + val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt + + override def start() { + super.start() + + // The endpoint for executors to talk to us + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + StandaloneSchedulerBackend.ACTOR_NAME) + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") + val command = Command( + "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) + val sparkHome = sc.getSparkHome().getOrElse(null) + val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, + "http://" + sc.ui.appUIAddress) + + client = new Client(sc.env.actorSystem, master, appDesc, this) + client.start() + } + + override def stop() { + stopping = true + super.stop() + client.stop() + if (shutdownCallback != null) { + shutdownCallback(this) + } + } + + override def connected(appId: String) { + logInfo("Connected to Spark cluster with app ID " + appId) + } + + override def disconnected() { + if (!stopping) { + logError("Disconnected from Spark cluster!") + scheduler.error("Disconnected from Spark cluster") + } + } + + 
override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { + logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + executorId, hostPort, cores, Utils.megabytesToString(memory))) + } + + override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { + val reason: ExecutorLossReason = exitStatus match { + case Some(code) => ExecutorExited(code) + case None => SlaveLost(message) + } + logInfo("Executor %s removed: %s".format(executorId, message)) + removeExecutor(executorId, reason.toString) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala new file mode 100644 index 0000000000..9c36d221f6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.nio.ByteBuffer + +import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Utils, SerializableBuffer} + + +private[spark] sealed trait StandaloneClusterMessage extends Serializable + +private[spark] object StandaloneClusterMessages { + + // Driver to executors + case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + + case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) + extends StandaloneClusterMessage + + case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage + + // Executors to driver + case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + extends StandaloneClusterMessage { + Utils.checkHostPort(hostPort, "Expected host port") + } + + case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, + data: SerializableBuffer) extends StandaloneClusterMessage + + object StatusUpdate { + /** Alternate factory method that takes a ByteBuffer directly for the data field */ + def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer) + : StatusUpdate = { + StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) + } + } + + // Internal messages in driver + case object ReviveOffers extends StandaloneClusterMessage + + case object StopDriver extends StandaloneClusterMessage + + case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala new file mode 100644 index 0000000000..addfa077c1 --- /dev/null +++ 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Await +import scala.concurrent.duration._ + +import akka.actor._ +import akka.pattern.ask +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} + +import org.apache.spark.{SparkException, Logging, TaskState} +import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.util.Utils + +/** + * A standalone scheduler backend, which waits for standalone executors to connect to it through + * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained + * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). + */ +private[spark] +class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) + extends SchedulerBackend with Logging +{ + // Use an atomic variable to track total number of cores in the cluster for simplicity and speed + var totalCoreCount = new AtomicInteger(0) + + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { + private val executorActor = new HashMap[String, ActorRef] + private val executorAddress = new HashMap[String, Address] + private val executorHost = new HashMap[String, String] + private val freeCores = new HashMap[String, Int] + private val actorToExecutorId = new HashMap[ActorRef, String] + private val addressToExecutorId = new HashMap[Address, String] + + override def preStart() { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + + // Periodically revive offers to allow delay scheduling to work + val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong + import context.dispatcher + context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) + } + + def receive = { + case RegisterExecutor(executorId, hostPort, cores) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorActor.contains(executorId)) { + sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) + } else { + logInfo("Registered executor: " + sender + " with ID " + executorId) + sender ! 
RegisteredExecutor(sparkProperties) + context.watch(sender) + executorActor(executorId) = sender + executorHost(executorId) = Utils.parseHostPort(hostPort)._1 + freeCores(executorId) = cores + executorAddress(executorId) = sender.path.address + actorToExecutorId(sender) = executorId + addressToExecutorId(sender.path.address) = executorId + totalCoreCount.addAndGet(cores) + makeOffers() + } + + case StatusUpdate(executorId, taskId, state, data) => + scheduler.statusUpdate(taskId, state, data.value) + if (TaskState.isFinished(state)) { + freeCores(executorId) += 1 + makeOffers(executorId) + } + + case ReviveOffers => + makeOffers() + + case StopDriver => + sender ! true + context.stop(self) + + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) + sender ! true + + case Terminated(actor) => + actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) + + case RemoteClientDisconnected(transport, address) => + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) + + case RemoteClientShutdown(transport, address) => + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) + } + + // Make fake resource offers on all executors + def makeOffers() { + launchTasks(scheduler.resourceOffers( + executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + } + + // Make fake resource offers on just one executor + def makeOffers(executorId: String) { + launchTasks(scheduler.resourceOffers( + Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + } + + // Launch tasks returned by a set of resource offers + def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + for (task <- tasks.flatten) { + freeCores(task.executorId) -= 1 + executorActor(task.executorId) ! LaunchTask(task) + } + } + + // Remove a disconnected slave from the cluster + def removeExecutor(executorId: String, reason: String) { + if (executorActor.contains(executorId)) { + logInfo("Executor " + executorId + " disconnected, so removing it") + val numCores = freeCores(executorId) + actorToExecutorId -= executorActor(executorId) + addressToExecutorId -= executorAddress(executorId) + executorActor -= executorId + executorHost -= executorId + freeCores -= executorId + totalCoreCount.addAndGet(-numCores) + scheduler.executorLost(executorId, SlaveLost(reason)) + } + } + } + + var driverActor: ActorRef = null + val taskIdsOnSlave = new HashMap[String, HashSet[String]] + + override def start() { + val properties = new ArrayBuffer[(String, String)] + val iterator = System.getProperties.entrySet.iterator + while (iterator.hasNext) { + val entry = iterator.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { + properties += ((key, value)) + } + } + driverActor = actorSystem.actorOf( + Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + } + + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + override def stop() { + try { + if (driverActor != null) { + val future = driverActor.ask(StopDriver)(timeout) + Await.result(future, timeout) + } + } catch { + case e: Exception => + throw new SparkException("Error stopping standalone scheduler's driver actor", e) + } + } + + override def reviveOffers() { + driverActor ! 
ReviveOffers + } + + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) + .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) + + // Called by subclasses when notified of a lost worker + def removeExecutor(executorId: String, reason: String) { + try { + val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) + Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error notifying standalone scheduler's driver actor", e) + } + } +} + +private[spark] object StandaloneSchedulerBackend { + val ACTOR_NAME = "StandaloneScheduler" +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala new file mode 100644 index 0000000000..309ac2f6c9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.nio.ByteBuffer +import org.apache.spark.util.SerializableBuffer + +private[spark] class TaskDescription( + val taskId: Long, + val executorId: String, + val name: String, + val index: Int, // Index within this task's TaskSet + _serializedTask: ByteBuffer) + extends Serializable { + + // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer + private val buffer = new SerializableBuffer(_serializedTask) + + def serializedTask: ByteBuffer = buffer.value + + override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala new file mode 100644 index 0000000000..9685fb1a67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.util.Utils + +/** + * Information about a running task attempt inside a TaskSet. + */ +private[spark] +class TaskInfo( + val taskId: Long, + val index: Int, + val launchTime: Long, + val executorId: String, + val host: String, + val taskLocality: TaskLocality.TaskLocality) { + + var finishTime: Long = 0 + var failed = false + + def markSuccessful(time: Long = System.currentTimeMillis) { + finishTime = time + } + + def markFailed(time: Long = System.currentTimeMillis) { + finishTime = time + failed = true + } + + def finished: Boolean = finishTime != 0 + + def successful: Boolean = finished && !failed + + def running: Boolean = !finished + + def status: String = { + if (running) + "RUNNING" + else if (failed) + "FAILED" + else if (successful) + "SUCCESS" + else + "UNKNOWN" + } + + def duration: Long = { + if (!finished) { + throw new UnsupportedOperationException("duration() called on unfinished tasks") + } else { + finishTime - launchTime + } + } + + def timeRunning(currentTime: Long): Long = currentTime - launchTime +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala new file mode 100644 index 0000000000..5d4130e14a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + + +private[spark] object TaskLocality + extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") +{ + // process local is expected to be used ONLY within tasksetmanager for now. + val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + condition <= constraint + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala new file mode 100644 index 0000000000..648a3ef922 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.nio.ByteBuffer + +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler.TaskSet + +/** + * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of + * each task and is responsible for retries on failure and locality. The main interfaces to it + * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and + * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * + * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler + * (e.g. its event handlers). It should not be called from other threads. + */ +private[spark] trait TaskSetManager extends Schedulable { + def schedulableQueue = null + + def schedulingMode = SchedulingMode.NONE + + def taskSet: TaskSet + + def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) + + def error(message: String) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala new file mode 100644 index 0000000000..938f62883a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +/** + * Represents free resources available on an executor. + */ +private[spark] +class WorkerOffer(val executorId: String, val host: String, val cores: Int) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala new file mode 100644 index 0000000000..e8fa5e2f17 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.local + +import java.io.File +import java.lang.management.ManagementFactory +import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.ExecutorURLClassLoader +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode +import akka.actor._ +import org.apache.spark.util.Utils + +/** + * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally + * the scheduler also allows each task to fail up to maxFailures times, which is useful for + * testing fault recovery. + */ + +private[spark] +case class LocalReviveOffers() + +private[spark] +case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private[spark] +class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { + + def receive = { + case LocalReviveOffers => + launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => + freeCores += 1 + localScheduler.statusUpdate(taskId, state, serializeData) + launchTask(localScheduler.resourceOffer(freeCores)) + } + + def launchTask(tasks : Seq[TaskDescription]) { + for (task <- tasks) { + freeCores -= 1 + localScheduler.threadPool.submit(new Runnable { + def run() { + localScheduler.runTask(task.taskId, task.serializedTask) + } + }) + } + } +} + +private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) + extends TaskScheduler + with Logging { + + var attemptId = new AtomicInteger(0) + var threadPool = Utils.newDaemonFixedThreadPool(threads) + val env = SparkEnv.get + var listener: TaskSchedulerListener = null + + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. 
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + + val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) + + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + val schedulingMode: SchedulingMode = SchedulingMode.withName( + System.getProperty("spark.cluster.schedulingmode", "FIFO")) + val activeTaskSets = new HashMap[String, TaskSetManager] + val taskIdToTaskSetId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + var localActor: ActorRef = null + + override def start() { + // temporarily set rootPool name to empty + rootPool = new Pool("", schedulingMode, 0, 0) + schedulableBuilder = { + schedulingMode match { + case SchedulingMode.FIFO => + new FIFOSchedulableBuilder(rootPool) + case SchedulingMode.FAIR => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + + localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") + } + + override def setListener(listener: TaskSchedulerListener) { + this.listener = listener + } + + override def submitTasks(taskSet: TaskSet) { + synchronized { + val manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! LocalReviveOffers + } + } + + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { + synchronized { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logDebug("parentName:%s,name:%s,runningTasks:%s".format( + manager.parent.name, manager.name, manager.runningTasks)) + } + + var launchTask = false + for (manager <- sortedTaskSetQueue) { + do { + launchTask = false + manager.resourceOffer(null, null, freeCpuCores, null) match { + case Some(task) => + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true + case None => {} + } + } while(launchTask) + } + return tasks + } + } + + def taskSetFinished(manager: TaskSetManager) { + synchronized { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= manager.taskSet.id + } + } + + def runTask(taskId: Long, bytes: ByteBuffer) { + logInfo("Running " + taskId) + val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + // Set the Spark execution environment for the worker thread + SparkEnv.set(env) + val ser = SparkEnv.get.closureSerializer.newInstance() + val objectSer = SparkEnv.get.serializer.newInstance() + var attemptedTask: Option[Task[_]] = None + val start = System.currentTimeMillis() + var taskStart: Long = 0 + def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum + val startGCTime = getTotalGCTime + + try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; + // this adds a bit of unnecessary overhead but 
matches how the Mesos Executor works. + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile + val deserializedTask = ser.deserialize[Task[_]]( + taskBytes, Thread.currentThread.getContextClassLoader) + attemptedTask = Some(deserializedTask) + val deserTime = System.currentTimeMillis() - start + taskStart = System.currentTimeMillis() + + // Run it + val result: Any = deserializedTask.run(taskId) + + // Serialize and deserialize the result to emulate what the Mesos + // executor does. This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). + val serResult = objectSer.serialize(result) + deserializedTask.metrics.get.resultSize = serResult.limit() + val resultToReturn = objectSer.deserialize[Any](serResult) + val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( + ser.serialize(Accumulators.values)) + val serviceTime = System.currentTimeMillis() - taskStart + logInfo("Finished " + taskId) + deserializedTask.metrics.get.executorRunTime = serviceTime.toInt + deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime + deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt + val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) + val serializedResult = ser.serialize(taskResult) + localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) + } catch { + case t: Throwable => { + val serviceTime = System.currentTimeMillis() - taskStart + val metrics = attemptedTask.flatMap(t => t.metrics) + for (m <- metrics) { + m.executorRunTime = serviceTime.toInt + m.jvmGCTime = getTotalGCTime - startGCTime + } + val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) + localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) + } + } + } + + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. 
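 * For example, a JAR registered on the driver through SparkContext.addJar appears in `newJars`
 * with the timestamp the master recorded; a timestamp newer than the one held in `currentJars`
 * triggers a re-fetch into SparkFiles.getRootDirectory and an addURL on the class loader.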
+ */ + private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!classLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + classLoader.addURL(url) + } + } + } + } + + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { + synchronized { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + taskSetManager.statusUpdate(taskId, state, serializedData) + } + } + + override def stop() { + threadPool.shutdownNow() + } + + override def defaultParallelism() = threads +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala new file mode 100644 index 0000000000..e52cb998bd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.local + +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler.{Task, TaskResult, TaskSet} +import org.apache.spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager} + + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) + extends TaskSetManager with Logging { + + var parent: Schedulable = null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = "TaskSet_" + taskSet.stageId.toString + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val MAX_TASK_FAILURES = sched.maxFailures + + override def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def addSchedulable(schedulable: Schedulable): Unit = { + // nothing + } + + override def removeSchedulable(schedulable: Schedulable): Unit = { + // nothing + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def executorLost(executorId: String, host: String): Unit = { + // nothing + } + + override def checkSpeculatableTasks() = true + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + override def hasPendingTasks() = true + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + override def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] = + { + SparkEnv.set(sched.env) + logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format( + availableCpus.toInt, numFinished, numTasks)) + if (availableCpus > 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", + TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here + // we assume the task can be serialized without exceptions. 
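// (The buffer built below bundles the task's file and JAR dependencies together with the
// serialized task itself; LocalScheduler.runTask later unpacks the same three pieces again
// via Task.deserializeWithDependencies before running the task.)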
+ val bytes = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + taskStarted(task, info) + return Some(new TaskDescription(taskId, null, taskName, index, bytes)) + case None => {} + } + } + return None + } + + override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + SparkEnv.set(env) + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskStarted(task: Task[_], info: TaskInfo) { + sched.listener.taskStarted(task, info) + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( + serializedData, getClass.getClassLoader) + sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format( + reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > MAX_TASK_FAILURES) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( + taskSet.id, index, 4, reason.description) + decreaseRunningTasks(runningTasks) + sched.listener.taskSetFailed(taskSet, errorMessage) + // need to delete failed Taskset from schedule queue + sched.taskSetFinished(this) + } + } + } + + override def error(message: String) { + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala new file mode 100644 index 0000000000..3dbe61d706 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import com.google.protobuf.ByteString + +import org.apache.mesos.{Scheduler => MScheduler} +import org.apache.mesos._ +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} + +import org.apache.spark.{SparkException, Logging, SparkContext} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ +import java.io.File +import org.apache.spark.scheduler.cluster._ +import java.util.{ArrayList => JArrayList, List => JList} +import java.util.Collections +import org.apache.spark.TaskState + +/** + * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds + * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever + * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the + * StandaloneBackend mechanism. This class is useful for lower and more predictable latency. + * + * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to + * remove this. + */ +private[spark] class CoarseMesosSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext, + master: String, + appName: String) + extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + with MScheduler + with Logging { + + val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures + + // Lock used to wait for scheduler to be registered + var isRegistered = false + val registeredLock = new Object() + + // Driver for talking to Mesos + var driver: SchedulerDriver = null + + // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) + val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt + + // Cores we have acquired with each Mesos task ID + val coresByTaskId = new HashMap[Int, Int] + var totalCoresAcquired = 0 + + val slaveIdsWithExecutors = new HashSet[String] + + val taskIdToSlaveId = new HashMap[Int, String] + val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) + + val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt + + var nextMesosTaskId = 0 + + def newMesosTaskId(): Int = { + val id = nextMesosTaskId + nextMesosTaskId += 1 + id + } + + override def start() { + super.start() + + synchronized { + new Thread("CoarseMesosSchedulerBackend driver") { + setDaemon(true) + override def run() { + val scheduler = CoarseMesosSchedulerBackend.this + val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() + driver = new MesosSchedulerDriver(scheduler, fwInfo, master) + try { { + val ret = driver.run() + logInfo("driver.run() returned with code " + ret) + } + } catch { + case e: Exception => logError("driver.run() failed", e) + } + } + }.start() + + waitForRegister() + } + } + + def createCommand(offer: Offer, numCores: Int): CommandInfo = { + val environment = Environment.newBuilder() + sc.executorEnvs.foreach { case (key, value) => + environment.addVariables(Environment.Variable.newBuilder() + .setName(key) + .setValue(value) + .build()) + } + val command = 
CommandInfo.newBuilder() + .setEnvironment(environment) + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port"), + StandaloneSchedulerBackend.ACTOR_NAME) + val uri = System.getProperty("spark.executor.uri") + if (uri == null) { + val runScript = new File(sparkHome, "spark-class").getCanonicalPath + command.setValue( + "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + } else { + // Grab everything to the first '.'. We'll use that and '*' to + // glob the directory "correctly". + val basename = uri.split('/').last.split('.').head + command.setValue( + "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( + basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + } + return command.build() + } + + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} + + override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + logInfo("Registered as framework ID " + frameworkId.getValue) + registeredLock.synchronized { + isRegistered = true + registeredLock.notifyAll() + } + } + + def waitForRegister() { + registeredLock.synchronized { + while (!isRegistered) { + registeredLock.wait() + } + } + } + + override def disconnected(d: SchedulerDriver) {} + + override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + + /** + * Method called by Mesos to offer resources on slaves. We respond by launching an executor, + * unless we've already launched more than we wanted to. + */ + override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { + synchronized { + val filters = Filters.newBuilder().setRefuseSeconds(-1).build() + + for (offer <- offers) { + val slaveId = offer.getSlaveId.toString + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus").toInt + if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId)) { + // Launch an executor on the slave + val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) + val taskId = newMesosTaskId() + taskIdToSlaveId(taskId) = slaveId + slaveIdsWithExecutors += slaveId + coresByTaskId(taskId) = cpusToUse + val task = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setName("Task " + taskId) + .addResources(createResource("cpus", cpusToUse)) + .addResources(createResource("mem", executorMemory)) + .build() + d.launchTasks(offer.getId, Collections.singletonList(task), filters) + } else { + // Filter it out + d.launchTasks(offer.getId, Collections.emptyList[MesosTaskInfo](), filters) + } + } + } + } + + /** Helper function to pull out a resource from a Mesos Resources protobuf */ + private def getResource(res: JList[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + // If we reached here, no resource with the required name was present + throw new IllegalArgumentException("No resource called " + name + " in " + res) + } + + /** Build a Mesos resource protobuf object */ + private def 
createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** Check whether a Mesos task state represents a finished task */ + private def isFinished(state: MesosTaskState) = { + state == MesosTaskState.TASK_FINISHED || + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { + val taskId = status.getTaskId.getValue.toInt + val state = status.getState + logInfo("Mesos task " + taskId + " is now " + state) + synchronized { + if (isFinished(state)) { + val slaveId = taskIdToSlaveId(taskId) + slaveIdsWithExecutors -= slaveId + taskIdToSlaveId -= taskId + // Remove the cores we have remembered for this task, if it's in the hashmap + for (cores <- coresByTaskId.get(taskId)) { + totalCoresAcquired -= cores + coresByTaskId -= taskId + } + // If it was a failure, mark the slave as failed for blacklisting purposes + if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { + failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 + if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { + logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + + "is Spark installed on it?") + } + } + driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + } + } + } + + override def error(d: SchedulerDriver, message: String) { + logError("Mesos error: " + message) + scheduler.error(message) + } + + override def stop() { + super.stop() + if (driver != null) { + driver.stop() + } + } + + override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { + logInfo("Mesos slave lost: " + slaveId.getValue) + synchronized { + if (slaveIdsWithExecutors.contains(slaveId.getValue)) { + // Note that the slave ID corresponds to the executor ID on that slave + slaveIdsWithExecutors -= slaveId.getValue + removeExecutor(slaveId.getValue, "Mesos slave lost") + } + } + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) + slaveLost(d, s) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala new file mode 100644 index 0000000000..541f86e338 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import com.google.protobuf.ByteString + +import org.apache.mesos.{Scheduler => MScheduler} +import org.apache.mesos._ +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} + +import org.apache.spark.{SparkException, Logging, SparkContext} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ +import java.io.File +import org.apache.spark.scheduler.cluster._ +import java.util.{ArrayList => JArrayList, List => JList} +import java.util.Collections +import org.apache.spark.TaskState +import org.apache.spark.util.Utils + +/** + * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a + * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks + * from multiple apps can run on different cores) and in time (a core can switch ownership). + */ +private[spark] class MesosSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext, + master: String, + appName: String) + extends SchedulerBackend + with MScheduler + with Logging { + + // Lock used to wait for scheduler to be registered + var isRegistered = false + val registeredLock = new Object() + + // Driver for talking to Mesos + var driver: SchedulerDriver = null + + // Which slave IDs we have executors on + val slaveIdsWithExecutors = new HashSet[String] + val taskIdToSlaveId = new HashMap[Long, String] + + // An ExecutorInfo for our tasks + var execArgs: Array[Byte] = null + + var classLoader: ClassLoader = null + + override def start() { + synchronized { + classLoader = Thread.currentThread.getContextClassLoader + + new Thread("MesosSchedulerBackend driver") { + setDaemon(true) + override def run() { + val scheduler = MesosSchedulerBackend.this + val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() + driver = new MesosSchedulerDriver(scheduler, fwInfo, master) + try { + val ret = driver.run() + logInfo("driver.run() returned with code " + ret) + } catch { + case e: Exception => logError("driver.run() failed", e) + } + } + }.start() + + waitForRegister() + } + } + + def createExecutorInfo(execId: String): ExecutorInfo = { + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) + val environment = Environment.newBuilder() + sc.executorEnvs.foreach { case (key, value) => + environment.addVariables(Environment.Variable.newBuilder() + .setName(key) + .setValue(value) + .build()) + } + val command = CommandInfo.newBuilder() + .setEnvironment(environment) + val uri = System.getProperty("spark.executor.uri") + if (uri == null) { + command.setValue(new File(sparkHome, "spark-executor").getCanonicalPath) + } else { + // Grab everything to the first '.'. We'll use that and '*' to + // glob the directory "correctly". 
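// For illustration: with spark.executor.uri set to, say, "hdfs://nn:9000/dist/spark-0.8.0.tar.gz"
// (a hypothetical value), `basename` below becomes "spark-0" and the executor command is
// "cd spark-0*; ./spark-executor".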
+ val basename = uri.split('/').last.split('.').head + command.setValue("cd %s*; ./spark-executor".format(basename)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + } + val memory = Resource.newBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) + .build() + ExecutorInfo.newBuilder() + .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) + .setCommand(command) + .setData(ByteString.copyFrom(createExecArg())) + .addResources(memory) + .build() + } + + /** + * Create and serialize the executor argument to pass to Mesos. Our executor arg is an array + * containing all the spark.* system properties in the form of (String, String) pairs. + */ + private def createExecArg(): Array[Byte] = { + if (execArgs == null) { + val props = new HashMap[String, String] + val iterator = System.getProperties.entrySet.iterator + while (iterator.hasNext) { + val entry = iterator.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.")) { + props(key) = value + } + } + // Serialize the map as an array of (String, String) pairs + execArgs = Utils.serialize(props.toArray) + } + return execArgs + } + + private def setClassLoader(): ClassLoader = { + val oldClassLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + return oldClassLoader + } + + private def restoreClassLoader(oldClassLoader: ClassLoader) { + Thread.currentThread.setContextClassLoader(oldClassLoader) + } + + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} + + override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + val oldClassLoader = setClassLoader() + try { + logInfo("Registered as framework ID " + frameworkId.getValue) + registeredLock.synchronized { + isRegistered = true + registeredLock.notifyAll() + } + } finally { + restoreClassLoader(oldClassLoader) + } + } + + def waitForRegister() { + registeredLock.synchronized { + while (!isRegistered) { + registeredLock.wait() + } + } + } + + override def disconnected(d: SchedulerDriver) {} + + override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + + /** + * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets + * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that + * tasks are balanced across the cluster. 
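 * Offers from slaves without enough free memory for an executor (and with no executor already
 * running there) are declined with an empty task list, while each remaining offer is answered
 * with the TaskDescriptions the ClusterScheduler assigned to that worker, converted into Mesos
 * TaskInfos before launch.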
+ */ + override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { + val oldClassLoader = setClassLoader() + try { + synchronized { + // Build a big list of the offerable workers, and remember their indices so that we can + // figure out which Offer to reply to for each worker + val offerableIndices = new ArrayBuffer[Int] + val offerableWorkers = new ArrayBuffer[WorkerOffer] + + def enoughMemory(o: Offer) = { + val mem = getResource(o.getResourcesList, "mem") + val slaveId = o.getSlaveId.getValue + mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) + } + + for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { + offerableIndices += index + offerableWorkers += new WorkerOffer( + offer.getSlaveId.getValue, + offer.getHostname, + getResource(offer.getResourcesList, "cpus").toInt) + } + + // Call into the ClusterScheduler + val taskLists = scheduler.resourceOffers(offerableWorkers) + + // Build a list of Mesos tasks for each slave + val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]()) + for ((taskList, index) <- taskLists.zipWithIndex) { + if (!taskList.isEmpty) { + val offerNum = offerableIndices(index) + val slaveId = offers(offerNum).getSlaveId.getValue + slaveIdsWithExecutors += slaveId + mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size) + for (taskDesc <- taskList) { + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) + } + } + } + + // Reply to the offers + val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? + for (i <- 0 until offers.size) { + d.launchTasks(offers(i).getId, mesosTasks(i), filters) + } + } + } finally { + restoreClassLoader(oldClassLoader) + } + } + + /** Helper function to pull out a resource from a Mesos Resources protobuf */ + def getResource(res: JList[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + // If we reached here, no resource with the required name was present + throw new IllegalArgumentException("No resource called " + name + " in " + res) + } + + /** Turn a Spark TaskDescription into a Mesos task */ + def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { + val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() + val cpuResource = Resource.newBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(1).build()) + .build() + return MesosTaskInfo.newBuilder() + .setTaskId(taskId) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setExecutor(createExecutorInfo(slaveId)) + .setName(task.name) + .addResources(cpuResource) + .setData(ByteString.copyFrom(task.serializedTask)) + .build() + } + + /** Check whether a Mesos task state represents a finished task */ + def isFinished(state: MesosTaskState) = { + state == MesosTaskState.TASK_FINISHED || + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { + val oldClassLoader = setClassLoader() + try { + val tid = status.getTaskId.getValue.toLong + val state = TaskState.fromMesos(status.getState) + synchronized { + if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + // We lost the executor on this slave, so remember that it's gone + slaveIdsWithExecutors -= taskIdToSlaveId(tid) + } + if (isFinished(status.getState)) { 
+ taskIdToSlaveId.remove(tid) + } + } + scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) + } finally { + restoreClassLoader(oldClassLoader) + } + } + + override def error(d: SchedulerDriver, message: String) { + val oldClassLoader = setClassLoader() + try { + logError("Mesos error: " + message) + scheduler.error(message) + } finally { + restoreClassLoader(oldClassLoader) + } + } + + override def stop() { + if (driver != null) { + driver.stop() + } + } + + override def reviveOffers() { + driver.reviveOffers() + } + + override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + + private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { + val oldClassLoader = setClassLoader() + try { + logInfo("Mesos slave lost: " + slaveId.getValue) + synchronized { + slaveIdsWithExecutors -= slaveId.getValue + } + scheduler.executorLost(slaveId.getValue, reason) + } finally { + restoreClassLoader(oldClassLoader) + } + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { + recordSlaveLost(d, slaveId, SlaveLost()) + } + + override def executorLost(d: SchedulerDriver, executorId: ExecutorID, + slaveId: SlaveID, status: Int) { + logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, + slaveId.getValue)) + recordSlaveLost(d, slaveId, ExecutorExited(status)) + } + + // TODO: query Mesos for number of cores + override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt +} diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala new file mode 100644 index 0000000000..4de81617b1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io._ +import java.nio.ByteBuffer + +import org.apache.spark.util.ByteBufferInputStream + +private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { + val objOut = new ObjectOutputStream(out) + def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } + def flush() { objOut.flush() } + def close() { objOut.close() } +} + +private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) +extends DeserializationStream { + val objIn = new ObjectInputStream(in) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + + def readObject[T](): T = objIn.readObject().asInstanceOf[T] + def close() { objIn.close() } +} + +private[spark] class JavaSerializerInstance extends SerializerInstance { + def serialize[T](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val out = serializeStream(bos) + out.writeObject(t) + out.close() + ByteBuffer.wrap(bos.toByteArray) + } + + def deserialize[T](bytes: ByteBuffer): T = { + val bis = new ByteBufferInputStream(bytes) + val in = deserializeStream(bis) + in.readObject().asInstanceOf[T] + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val bis = new ByteBufferInputStream(bytes) + val in = deserializeStream(bis, loader) + in.readObject().asInstanceOf[T] + } + + def serializeStream(s: OutputStream): SerializationStream = { + new JavaSerializationStream(s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader) + } + + def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { + new JavaDeserializationStream(s, loader) + } +} + +/** + * A Spark serializer that uses Java's built-in serialization. + */ +class JavaSerializer extends Serializer { + def newInstance(): SerializerInstance = new JavaSerializerInstance +} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala new file mode 100644 index 0000000000..24ef204aa1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.nio.ByteBuffer +import java.io.{EOFException, InputStream, OutputStream} + +import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.{KryoException, Kryo} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import com.twitter.chill.ScalaKryoInstantiator + +import org.apache.spark.{SerializableWritable, Logging} +import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel} + +import org.apache.spark.broadcast.HttpBroadcast + +/** + * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. + */ +class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { + private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + + def newKryoOutput() = new KryoOutput(bufferSize) + + def newKryoInput() = new KryoInput(bufferSize) + + def newKryo(): Kryo = { + val instantiator = new ScalaKryoInstantiator + val kryo = instantiator.newKryo() + val classLoader = Thread.currentThread.getContextClassLoader + + // Register some commonly used classes + val toRegister: Seq[AnyRef] = Seq( + ByteBuffer.allocate(1), + StorageLevel.MEMORY_ONLY, + PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), + GotBlock("1", ByteBuffer.allocate(1)), + GetBlock("1") + ) + + for (obj <- toRegister) kryo.register(obj.getClass) + + // Allow sending SerializableWritable + kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) + + // Allow the user to register their own classes by setting spark.kryo.registrator + try { + Option(System.getProperty("spark.kryo.registrator")).foreach { regCls => + logDebug("Running user registrator: " + regCls) + val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] + reg.registerClasses(kryo) + } + } catch { + case _: Exception => println("Failed to register spark.kryo.registrator") + } + + kryo.setClassLoader(classLoader) + + // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops + kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) + + kryo + } + + def newInstance(): SerializerInstance = { + new KryoSerializerInstance(this) + } +} + +private[spark] +class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { + val output = new KryoOutput(outStream) + + def writeObject[T](t: T): SerializationStream = { + kryo.writeClassAndObject(output, t) + this + } + + def flush() { output.flush() } + def close() { output.close() } +} + +private[spark] +class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { + val input = new KryoInput(inStream) + + def readObject[T](): T = { + try { + kryo.readClassAndObject(input).asInstanceOf[T] + } catch { + // DeserializationStream uses the EOF exception to indicate stopping condition. + case _: KryoException => throw new EOFException + } + } + + def close() { + // Kryo's Input automatically closes the input stream it is using. 
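+    // Closing `input` here therefore also closes the wrapped `inStream`; no separate
+    // cleanup of the underlying stream is needed.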
+ input.close() + } +} + +private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { + val kryo = ks.newKryo() + val output = ks.newKryoOutput() + val input = ks.newKryoInput() + + def serialize[T](t: T): ByteBuffer = { + output.clear() + kryo.writeClassAndObject(output, t) + ByteBuffer.wrap(output.toBytes) + } + + def deserialize[T](bytes: ByteBuffer): T = { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val oldClassLoader = kryo.getClassLoader + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + val obj = kryo.readClassAndObject(input).asInstanceOf[T] + kryo.setClassLoader(oldClassLoader) + obj + } + + def serializeStream(s: OutputStream): SerializationStream = { + new KryoSerializationStream(kryo, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new KryoDeserializationStream(kryo, s) + } +} + +/** + * Interface implemented by clients to register their classes with Kryo when using Kryo + * serialization. + */ +trait KryoRegistrator { + def registerClasses(kryo: Kryo) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala new file mode 100644 index 0000000000..160cca4d6c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import org.apache.spark.util.{NextIterator, ByteBufferInputStream} + + +/** + * A serializer. Because some serialization libraries are not thread safe, this class is used to + * create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual serialization and are + * guaranteed to only be called from one thread at a time. + */ +trait Serializer { + def newInstance(): SerializerInstance +} + + +/** + * An instance of a serializer, for use by one thread at a time. 
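+ *
+ * For example (an illustrative sketch using the Java-based serializer defined in this patch):
+ * {{{
+ *   val instance = new JavaSerializer().newInstance()
+ *   val bytes    = instance.serialize(Seq(1, 2, 3))
+ *   val back     = instance.deserialize[Seq[Int]](bytes)
+ * }}}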
+ */ +trait SerializerInstance { + def serialize[T](t: T): ByteBuffer + + def deserialize[T](bytes: ByteBuffer): T + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T + + def serializeStream(s: OutputStream): SerializationStream + + def deserializeStream(s: InputStream): DeserializationStream + + def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { + // Default implementation uses serializeStream + val stream = new FastByteArrayOutputStream() + serializeStream(stream).writeAll(iterator) + val buffer = ByteBuffer.allocate(stream.position.toInt) + buffer.put(stream.array, 0, stream.position.toInt) + buffer.flip() + buffer + } + + def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { + // Default implementation uses deserializeStream + buffer.rewind() + deserializeStream(new ByteBufferInputStream(buffer)).asIterator + } +} + + +/** + * A stream for writing serialized objects. + */ +trait SerializationStream { + def writeObject[T](t: T): SerializationStream + def flush(): Unit + def close(): Unit + + def writeAll[T](iter: Iterator[T]): SerializationStream = { + while (iter.hasNext) { + writeObject(iter.next()) + } + this + } +} + + +/** + * A stream for reading serialized objects. + */ +trait DeserializationStream { + def readObject[T](): T + def close(): Unit + + /** + * Read the elements of this stream through an iterator. This can only be called once, as + * reading each element will consume data from the input source. + */ + def asIterator: Iterator[Any] = new NextIterator[Any] { + override protected def getNext() = { + try { + readObject[Any]() + } catch { + case eof: EOFException => + finished = true + } + } + + override protected def close() { + DeserializationStream.this.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala new file mode 100644 index 0000000000..2955986fec --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.util.concurrent.ConcurrentHashMap + + +/** + * A service that returns a serializer object given the serializer's class name. If a previous + * instance of the serializer object has been created, the get method returns that instead of + * creating a new one. 
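+ *
+ * For example (an illustrative sketch; the class name passed below is the Kryo serializer
+ * added elsewhere in this patch):
+ * {{{
+ *   val manager = new SerializerManager
+ *   val first   = manager.get("org.apache.spark.serializer.KryoSerializer")
+ *   val second  = manager.get("org.apache.spark.serializer.KryoSerializer")
+ *   // `first eq second`: the cached instance is returned on the second lookup
+ * }}}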
+ */ +private[spark] class SerializerManager { + + private val serializers = new ConcurrentHashMap[String, Serializer] + private var _default: Serializer = _ + + def default = _default + + def setDefault(clsName: String): Serializer = { + _default = get(clsName) + _default + } + + def get(clsName: String): Serializer = { + if (clsName == null) { + default + } else { + var serializer = serializers.get(clsName) + if (serializer != null) { + // If the serializer has been created previously, reuse that. + serializer + } else this.synchronized { + // Otherwise, create a new one. But make sure no other thread has attempted + // to create another new one at the same time. + serializer = serializers.get(clsName) + if (serializer == null) { + val clsLoader = Thread.currentThread.getContextClassLoader + serializer = + Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + serializers.put(clsName, serializer) + } + serializer + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala new file mode 100644 index 0000000000..290dbce4f5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +private[spark] +case class BlockException(blockId: String, message: String) extends Exception(message) + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala new file mode 100644 index 0000000000..2e0b0e6eda --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +private[spark] trait BlockFetchTracker { + def totalBlocks : Int + def numLocalBlocks: Int + def numRemoteBlocks: Int + def remoteFetchTime : Long + def fetchWaitTime: Long + def remoteBytesRead : Long +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala new file mode 100644 index 0000000000..3aeda3879d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import io.netty.buffer.ByteBuf + +import org.apache.spark.Logging +import org.apache.spark.SparkException +import org.apache.spark.network.BufferMessage +import org.apache.spark.network.ConnectionManagerId +import org.apache.spark.network.netty.ShuffleCopier +import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils + + +/** + * A block fetcher iterator interface. There are two implementations: + * + * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. + * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. + * + * Eventually we would like the two to converge and use a single NIO-based communication layer, + * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), + * NIO would perform poorly and thus the need for the Netty OIO one. + */ + +private[storage] +trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] + with Logging with BlockFetchTracker { + def initialize() +} + + +private[storage] +object BlockFetcherIterator { + + // A request to fetch one or more blocks, complete with their sizes + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. 
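+ // For example (illustrative): a failed remote fetch is enqueued as
+ //   new FetchResult(blockId, -1, null)
+ // and consumers must check `failed` before calling `deserialize()`.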
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } + + class BasicBlockFetcherIterator( + private val blockManager: BlockManager, + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer) + extends BlockFetcherIterator { + + import blockManager._ + + private var _remoteBytesRead = 0l + private var _remoteFetchTime = 0l + private var _fetchWaitTime = 0l + + if (blocksByAddress == null) { + throw new IllegalArgumentException("BlocksByAddress is null") + } + + // Total number blocks fetched (local + remote). Also number of FetchResults expected + protected var _numBlocksToFetch = 0 + + protected var startTime = System.currentTimeMillis + + // This represents the number of local blocks, also counting zero-sized blocks + private var numLocal = 0 + // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks + protected val localBlocksToFetch = new ArrayBuffer[String]() + + // This represents the number of remote blocks, also counting zero-sized blocks + private var numRemote = 0 + // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks + protected val remoteBlocksToFetch = new HashSet[String]() + + // A queue to hold our results. + protected val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + private val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + private var bytesInFlight = 0L + + protected def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) + val blockMessageArray = new BlockMessageArray(req.blocks.map { + case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) + }) + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + val fetchStart = System.currentTimeMillis() + val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + future.onSuccess { + case Some(message) => { + val fetchDone = System.currentTimeMillis() + _remoteFetchTime += fetchDone - fetchStart + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) + } + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += networkSize + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + } + case None => { + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + } + } + + protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. 
Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + numLocal = blockInfos.size + // Filter out zero-sized blocks + localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) + _numBlocksToFetch += localBlocksToFetch.size + } else { + numRemote += blockInfos.size + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + remoteBlocksToFetch += blockId + _numBlocksToFetch += 1 + curRequestSize += size + } else if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + + totalBlocks + " blocks") + remoteRequests + } + + protected def getLocalBlocks() { + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlocksToFetch) { + getLocalFromDisk(id, serializer) match { + case Some(iter) => { + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } + } + + override def initialize() { + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numGets = remoteRequests.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + //an iterator that will read fetched blocks off the queue as they arrive. 
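+ // For example (illustrative numbers): with the default maxBytesInFlight of 48 MB, requests
+ // are built up to roughly 9.6 MB each, so about five of them can be outstanding at once;
+ // every call to next() takes one result off the queue and, once enough bytes in flight have
+ // been freed, dequeues further FetchRequests.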
+ @volatile protected var resultsGotten = 0 + + override def hasNext: Boolean = resultsGotten < _numBlocksToFetch + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + _fetchWaitTime += (stopFetchWait - startFetchWait) + if (! result.failed) bytesInFlight -= result.size + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + + // Implementing BlockFetchTracker trait. + override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote + override def remoteFetchTime: Long = _remoteFetchTime + override def fetchWaitTime: Long = _fetchWaitTime + override def remoteBytesRead: Long = _remoteBytesRead + } + // End of BasicBlockFetcherIterator + + class NettyBlockFetcherIterator( + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + + import blockManager._ + + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + private def startCopiers(numCopiers: Int): List[_ <: Thread] = { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + //case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } + } + } + copier.start + copier + }).toList + } + + // keep this to interrupt the threads when necessary + private def stopCopiers() { + for (copier <- copiers) { + copier.interrupt() + } + } + + override protected def sendRequest(req: FetchRequest) { + + def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + val fetchResult = new FetchResult(blockId, blockSize, + () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) + results.put(fetchResult) + } + + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.host)) + val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) + val cpier = new ShuffleCopier + cpier.getBlocks(cmId, req.blocks, putResult) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) + } + + private var copiers: List[_ <: Thread] = null + + override def initialize() { + // Split Local Remote Blocks and set numBlocksToFetch + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // If all the results has been retrieved, 
copiers will exit automatically + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + } + // End of NettyBlockFetcherIterator +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala new file mode 100644 index 0000000000..13b98a51a1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -0,0 +1,1046 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{InputStream, OutputStream} +import java.nio.{ByteBuffer, MappedByteBuffer} + +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} + +import akka.actor.{ActorSystem, Cancellable, Props} +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.duration._ + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.network._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.util._ + +import sun.nio.ch.DirectBuffer + + +private[spark] class BlockManager( + executorId: String, + actorSystem: ActorSystem, + val master: BlockManagerMaster, + val defaultSerializer: Serializer, + maxMemory: Long) + extends Logging { + + private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + @volatile var pending: Boolean = true + @volatile var size: Long = -1L + @volatile var initThread: Thread = null + @volatile var failed = false + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + this.initThread = Thread.currentThread() + } + + /** + * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). + * Return true if the block is available, false otherwise. + */ + def waitForReady(): Boolean = { + if (initThread != Thread.currentThread() && pending) { + synchronized { + while (pending) this.wait() + } + } + !failed + } + + /** Mark this BlockInfo as ready (i.e. 
block is finished writing) */ + def markReady(sizeInBytes: Long) { + assert (pending) + size = sizeInBytes + initThread = null + failed = false + initThread = null + pending = false + synchronized { + this.notifyAll() + } + } + + /** Mark this BlockInfo as ready but failed */ + def markFailure() { + assert (pending) + size = 0 + initThread = null + failed = true + initThread = null + pending = false + synchronized { + this.notifyAll() + } + } + } + + val shuffleBlockManager = new ShuffleBlockManager(this) + + private val blockInfo = new TimeStampedHashMap[String, BlockInfo] + + private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val diskStore: DiskStore = + new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + + // If we use Netty for shuffle, start a new Netty-based shuffle sender service. + private val nettyPort: Int = { + val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt + if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + } + + val connectionManager = new ConnectionManager(0) + implicit val futureExecContext = connectionManager.futureExecContext + + val blockManagerId = BlockManagerId( + executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) + + // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory + // for receiving shuffle outputs) + val maxBytesInFlight = + System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + + // Whether to compress broadcast variables that are stored + val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + // Whether to compress shuffle output that are stored + val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean + // Whether to compress RDD partitions that are stored serialized + val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean + + val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties + + val hostPort = Utils.localHostPort() + + val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + + // Pending reregistration action being executed asynchronously or null if none + // is pending. Accesses should synchronize on asyncReregisterLock. + var asyncReregisterTask: Future[Unit] = null + val asyncReregisterLock = new Object + + private def heartBeat() { + if (!master.sendHeartBeat(blockManagerId)) { + reregister() + } + } + + var heartBeatTask: Cancellable = null + + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) + initialize() + + // The compression codec to use. Note that the "lazy" val is necessary because we want to delay + // the initialization of the compression codec until it is first used. The reason is that a Spark + // program could be using a user-defined codec in a third party jar, which is loaded in + // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been + // loaded yet. + private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec() + + /** + * Construct a BlockManager with a memory limit set based on system properties. 
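+ *
+ * For example (illustrative): an executor can simply call
+ * {{{
+ *   new BlockManager(executorId, actorSystem, master, serializer)
+ * }}}
+ * and let the companion object derive the memory limit from system properties.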
+ */ + def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, + serializer: Serializer) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + } + + /** + * Initialize the BlockManager. Register to the BlockManagerMaster, and start the + * BlockManagerWorker actor. + */ + private def initialize() { + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + BlockManagerWorker.startBlockManagerWorker(this) + if (!BlockManager.getDisableHeartBeatsForTesting) { + heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { + heartBeat() + } + } + } + + /** + * Report all blocks to the BlockManager again. This may be necessary if we are dropped + * by the BlockManager and come back or if we become capable of recovering blocks on disk after + * an executor crash. + * + * This function deliberately fails silently if the master returns false (indicating that + * the slave needs to reregister). The error condition will be detected again by the next + * heart beat attempt or new block registration and another try to reregister all blocks + * will be made then. + */ + private def reportAllBlocks() { + logInfo("Reporting " + blockInfo.size + " blocks to the master.") + for ((blockId, info) <- blockInfo) { + if (!tryToReportBlockStatus(blockId, info)) { + logError("Failed to report " + blockId + " to master; giving up.") + return + } + } + } + + /** + * Reregister with the master and report all blocks to it. This will be called by the heart beat + * thread if our heartbeat to the block amnager indicates that we were not registered. + * + * Note that this method must be called without any BlockInfo locks held. + */ + def reregister() { + // TODO: We might need to rate limit reregistering. + logInfo("BlockManager reregistering with master") + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + reportAllBlocks() + } + + /** + * Reregister with the master sometime soon. + */ + def asyncReregister() { + asyncReregisterLock.synchronized { + if (asyncReregisterTask == null) { + asyncReregisterTask = Future[Unit] { + reregister() + asyncReregisterLock.synchronized { + asyncReregisterTask = null + } + } + } + } + } + + /** + * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing. + */ + def waitForAsyncReregister() { + val task = asyncReregisterTask + if (task != null) { + Await.ready(task, Duration.Inf) + } + } + + /** + * Get storage level of local block. If no info exists for the block, then returns null. + */ + def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + + /** + * Tell the master about the current storage status of a block. This will send a block update + * message reflecting the current status, *not* the desired storage level in its block info. + * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. + * + * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). + * This ensures that update in master will compensate for the increase in memory on slave. + */ + def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) + if (needReregister) { + logInfo("Got told to reregister updating block " + blockId) + // Reregistering will report our new block for free. 
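+ // (reregister() calls registerBlockManager and then reportAllBlocks, which covers this block.)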
+ asyncReregister() + } + logDebug("Told master about block " + blockId) + } + + /** + * Actually send a UpdateBlockInfo message. Returns the mater's response, + * which will be true if the block was successfully recorded and false if + * the slave needs to re-register. + */ + private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { + val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { + info.level match { + case null => + (StorageLevel.NONE, 0L, 0L, false) + case level => + val inMem = level.useMemory && memoryStore.contains(blockId) + val onDisk = level.useDisk && diskStore.contains(blockId) + val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize + val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L + (storageLevel, memSize, diskSize, info.tellMaster) + } + } + + if (tellMaster) { + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) + } else { + true + } + } + + /** + * Get locations of an array of blocks. + */ + def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { + val startTimeMs = System.currentTimeMillis + val locations = master.getLocations(blockIds).toArray + logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + locations + } + + /** + * A short-circuited method to get blocks directly from disk. This is used for getting + * shuffle blocks. It is safe to do so without a lock on block info since disk store + * never deletes (recent) items. + */ + def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + diskStore.getValues(blockId, serializer).orElse( + sys.error("Block " + blockId + " not found on disk, though it should be")) + } + + /** + * Get block from local block manager. + */ + def getLocal(blockId: String): Option[Iterator[Any]] = { + logDebug("Getting local block " + blockId) + val info = blockInfo.get(blockId).orNull + if (info != null) { + info.synchronized { + + // In the another thread is writing the block, wait for it to become ready. + if (!info.waitForReady()) { + // If we get here, the block write failed. + logWarning("Block " + blockId + " was marked as failure.") + return None + } + + val level = info.level + logDebug("Level for block " + blockId + " is " + level) + + // Look for the block in memory + if (level.useMemory) { + logDebug("Getting block " + blockId + " from memory") + memoryStore.getValues(blockId) match { + case Some(iterator) => + return Some(iterator) + case None => + logDebug("Block " + blockId + " not found in memory") + } + } + + // Look for block on disk, potentially loading it back into memory if required + if (level.useDisk) { + logDebug("Getting block " + blockId + " from disk") + if (level.useMemory && level.deserialized) { + diskStore.getValues(blockId) match { + case Some(iterator) => + // Put the block back in memory before returning it + // TODO: Consider creating a putValues that also takes in a iterator ? 
+ val elements = new ArrayBuffer[Any] + elements ++= iterator + memoryStore.putValues(blockId, elements, level, true).data match { + case Left(iterator2) => + return Some(iterator2) + case _ => + throw new Exception("Memory store did not return back an iterator") + } + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } else if (level.useMemory && !level.deserialized) { + // Read it as a byte buffer into memory first, then return it + diskStore.getBytes(blockId) match { + case Some(bytes) => + // Put a copy of the block back in memory before returning it. Note that we can't + // put the ByteBuffer returned by the disk store as that's a memory-mapped file. + // The use of rewind assumes this. + assert (0 == bytes.position()) + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + bytes.rewind() + return Some(dataDeserialize(blockId, bytes)) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } else { + diskStore.getValues(blockId) match { + case Some(iterator) => + return Some(iterator) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + } + } + } else { + logDebug("Block " + blockId + " not registered locally") + } + return None + } + + /** + * Get block from the local block manager as serialized bytes. + */ + def getLocalBytes(blockId: String): Option[ByteBuffer] = { + // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow + logDebug("Getting local block " + blockId + " as bytes") + + // As an optimization for map output fetches, if the block is for a shuffle, return it + // without acquiring a lock; the disk store never deletes (recent) items so this should work + if (ShuffleBlockManager.isShuffle(blockId)) { + return diskStore.getBytes(blockId) match { + case Some(bytes) => + Some(bytes) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + + val info = blockInfo.get(blockId).orNull + if (info != null) { + info.synchronized { + + // In the another thread is writing the block, wait for it to become ready. + if (!info.waitForReady()) { + // If we get here, the block write failed. 
+ logWarning("Block " + blockId + " was marked as failure.") + return None + } + + val level = info.level + logDebug("Level for block " + blockId + " is " + level) + + // Look for the block in memory + if (level.useMemory) { + logDebug("Getting block " + blockId + " from memory") + memoryStore.getBytes(blockId) match { + case Some(bytes) => + return Some(bytes) + case None => + logDebug("Block " + blockId + " not found in memory") + } + } + + // Look for block on disk + if (level.useDisk) { + // Read it as a byte buffer into memory first, then return it + diskStore.getBytes(blockId) match { + case Some(bytes) => + assert (0 == bytes.position()) + if (level.useMemory) { + if (level.deserialized) { + memoryStore.putBytes(blockId, bytes, level) + } else { + // The memory store will hang onto the ByteBuffer, so give it a copy instead of + // the memory-mapped file buffer we got from the disk store + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + } + } + bytes.rewind() + return Some(bytes) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + } + } else { + logDebug("Block " + blockId + " not registered locally") + } + return None + } + + /** + * Get block from remote block managers. + */ + def getRemote(blockId: String): Option[Iterator[Any]] = { + if (blockId == null) { + throw new IllegalArgumentException("Block Id is null") + } + logDebug("Getting remote block " + blockId) + // Get locations of block + val locations = master.getLocations(blockId) + + // Get block from remote locations + for (loc <- locations) { + logDebug("Getting remote block " + blockId + " from " + loc) + val data = BlockManagerWorker.syncGetBlock( + GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) + if (data != null) { + return Some(dataDeserialize(blockId, data)) + } + logDebug("The value of block " + blockId + " is null") + } + logDebug("Block " + blockId + " not found") + return None + } + + /** + * Get a block from the block manager (either local or remote). + */ + def get(blockId: String): Option[Iterator[Any]] = { + getLocal(blockId).orElse(getRemote(blockId)) + } + + /** + * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns + * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined + * fashion as they're received. Expects a size in bytes to be provided for each block fetched, + * so that we can control the maxMegabytesInFlight for the fetch. + */ + def getMultiple( + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) + : BlockFetcherIterator = { + + val iter = + if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + } else { + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + } + + iter.initialize() + iter + } + + def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) + : Long = { + val elements = new ArrayBuffer[Any] + elements ++= values + put(blockId, elements, level, tellMaster) + } + + /** + * A short circuited method to get a block writer that can write data directly to disk. + * This is currently used for writing shuffle files out. Callers should handle error + * cases. 
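+ *
+ * For example (an illustrative sketch; the block id and buffer size below are made up):
+ * {{{
+ *   val writer = getDiskBlockWriter("shuffle_0_1_2", serializer, 32 * 1024)
+ *   // write the partition's records, then close the writer; the registered close handler
+ *   // records the block locally as DISK_ONLY with the number of bytes actually written
+ * }}}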
+ */ + def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + : BlockObjectWriter = { + val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) + writer.registerCloseEventHandler(() => { + val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) + blockInfo.put(blockId, myInfo) + myInfo.markReady(writer.size()) + }) + writer + } + + /** + * Put a new block of values to the block manager. Returns its (estimated) size in bytes. + */ + def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + tellMaster: Boolean = true) : Long = { + + if (blockId == null) { + throw new IllegalArgumentException("Block Id is null") + } + if (values == null) { + throw new IllegalArgumentException("Values is null") + } + if (level == null || !level.isValid) { + throw new IllegalArgumentException("Storage level is null or invalid") + } + + // Remember the block's storage level so that we can correctly drop it to disk if it needs + // to be dropped right after it got put into memory. Note, however, that other threads will + // not be able to get() this block until we call markReady on its BlockInfo. + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return oldBlockOpt.get.size + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } + } + + val startTimeMs = System.currentTimeMillis + + // If we need to replicate the data, we'll want access to the values, but because our + // put will read the whole iterator, there will be no values left. For the case where + // the put serializes data, we'll remember the bytes, above; but for the case where it + // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + var valuesAfterPut: Iterator[Any] = null + + // Ditto for the bytes after the put + var bytesAfterPut: ByteBuffer = null + + // Size of the block in bytes (to return to caller) + var size = 0L + + myInfo.synchronized { + logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + + " to get into synchronized block") + + var marked = false + try { + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will later + // drop it to disk if the memory store can't hold it. + val res = memoryStore.putValues(blockId, values, level, true) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator + } + } else { + // Save directly to disk. + // Don't get back the bytes unless we replicate them. + val askForBytes = level.replication > 1 + val res = diskStore.putValues(blockId, values, level, askForBytes) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => + } + } + + // Now that the block is in either the memory or disk store, let other threads read it, + // and tell the master about it. + marked = true + myInfo.markReady(size) + if (tellMaster) { + reportBlockStatus(blockId, myInfo) + } + } finally { + // If we failed at putting the block to memory/disk, notify other possible readers + // that it has failed, and then remove it from the block info map. + if (! 
marked) { + // Note that the remove must happen before markFailure otherwise another thread + // could've inserted a new BlockInfo before we remove it. + blockInfo.remove(blockId) + myInfo.markFailure() + logWarning("Putting block " + blockId + " failed") + } + } + } + logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) + + // Replicate block if required + if (level.replication > 1) { + val remoteStartTime = System.currentTimeMillis + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + } + replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) + } + BlockManager.dispose(bytesAfterPut) + + return size + } + + + /** + * Put a new block of serialized bytes to the block manager. + */ + def putBytes( + blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { + + if (blockId == null) { + throw new IllegalArgumentException("Block Id is null") + } + if (bytes == null) { + throw new IllegalArgumentException("Bytes is null") + } + if (level == null || !level.isValid) { + throw new IllegalArgumentException("Storage level is null or invalid") + } + + // Remember the block's storage level so that we can correctly drop it to disk if it needs + // to be dropped right after it got put into memory. Note, however, that other threads will + // not be able to get() this block until we call markReady on its BlockInfo. + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } + } + + val startTimeMs = System.currentTimeMillis + + // Initiate the replication before storing it locally. This is faster as + // data is already serialized and ready for sending + val replicationFuture = if (level.replication > 1) { + val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper + Future { + replicate(blockId, bufferView, level) + } + } else { + null + } + + myInfo.synchronized { + logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + + " to get into synchronized block") + + var marked = false + try { + if (level.useMemory) { + // Store it only in memory at first, even if useDisk is also set to true + bytes.rewind() + memoryStore.putBytes(blockId, bytes, level) + } else { + bytes.rewind() + diskStore.putBytes(blockId, bytes, level) + } + + // assert (0 == bytes.position(), "" + bytes) + + // Now that the block is in either the memory or disk store, let other threads read it, + // and tell the master about it. + marked = true + myInfo.markReady(bytes.limit) + if (tellMaster) { + reportBlockStatus(blockId, myInfo) + } + } finally { + // If we failed at putting the block to memory/disk, notify other possible readers + // that it has failed, and then remove it from the block info map. + if (! 
marked) { + // Note that the remove must happen before markFailure otherwise another thread + // could've inserted a new BlockInfo before we remove it. + blockInfo.remove(blockId) + myInfo.markFailure() + logWarning("Putting block " + blockId + " failed") + } + } + } + + // If replication had started, then wait for it to finish + if (level.replication > 1) { + Await.ready(replicationFuture, Duration.Inf) + } + + if (level.replication > 1) { + logDebug("PutBytes for block " + blockId + " with replication took " + + Utils.getUsedTimeMs(startTimeMs)) + } else { + logDebug("PutBytes for block " + blockId + " without replication took " + + Utils.getUsedTimeMs(startTimeMs)) + } + } + + /** + * Replicate block to another node. + */ + var cachedPeers: Seq[BlockManagerId] = null + private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + if (cachedPeers == null) { + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) + } + for (peer: BlockManagerId <- cachedPeers) { + val start = System.nanoTime + data.rewind() + logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " + + data.limit() + " Bytes. To node: " + peer) + if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), + new ConnectionManagerId(peer.host, peer.port))) { + logError("Failed to call syncPutBlock to " + peer) + } + logDebug("Replicated BlockId " + blockId + " once used " + + (System.nanoTime - start) / 1e6 + " s; The size of the data is " + + data.limit() + " bytes.") + } + } + + /** + * Read a block consisting of a single object. + */ + def getSingle(blockId: String): Option[Any] = { + get(blockId).map(_.next()) + } + + /** + * Write a block consisting of a single object. + */ + def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { + put(blockId, Iterator(value), level, tellMaster) + } + + /** + * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory + * store reaches its limit and needs to free up space. + */ + def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { + logInfo("Dropping block " + blockId + " from memory") + val info = blockInfo.get(blockId).orNull + if (info != null) { + info.synchronized { + // required ? As of now, this will be invoked only for blocks which are ready + // But in case this changes in future, adding for consistency sake. + if (! info.waitForReady() ) { + // If we get here, the block write failed. + logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + return + } + + val level = info.level + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo("Writing block " + blockId + " to disk") + data match { + case Left(elements) => + diskStore.putValues(blockId, elements, level, false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + } + val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val blockWasRemoved = memoryStore.remove(blockId) + if (!blockWasRemoved) { + logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") + } + if (info.tellMaster) { + reportBlockStatus(blockId, info, droppedMemorySize) + } + if (!level.useDisk) { + // The block is completely gone from this node; forget it so we can put() it again later. 
+ blockInfo.remove(blockId) + } + } + } else { + // The block has already been dropped + } + } + + /** + * Remove all blocks belonging to the given RDD. + * @return The number of blocks removed. + */ + def removeRdd(rddId: Int): Int = { + // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps + // from RDD.id to blocks. + logInfo("Removing RDD " + rddId) + val rddPrefix = "rdd_" + rddId + "_" + val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) + blocksToRemove.foreach(blockId => removeBlock(blockId, false)) + blocksToRemove.size + } + + /** + * Remove a block from both memory and disk. + */ + def removeBlock(blockId: String, tellMaster: Boolean = true) { + logInfo("Removing block " + blockId) + val info = blockInfo.get(blockId).orNull + if (info != null) info.synchronized { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + if (!removedFromMemory && !removedFromDisk) { + logWarning("Block " + blockId + " could not be removed as it was not found in either " + + "the disk or memory store") + } + blockInfo.remove(blockId) + if (tellMaster && info.tellMaster) { + reportBlockStatus(blockId, info) + } + } else { + // The block has already been removed; do nothing. + logWarning("Asked to remove block " + blockId + ", which does not exist") + } + } + + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id, info) + } + } + } + + def shouldCompress(blockId: String): Boolean = { + if (ShuffleBlockManager.isShuffle(blockId)) { + compressShuffle + } else if (blockId.startsWith("broadcast_")) { + compressBroadcast + } else if (blockId.startsWith("rdd_")) { + compressRdds + } else { + false // Won't happen in a real cluster, but it can in tests + } + } + + /** + * Wrap an output stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s + } + + /** + * Wrap an input stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: String, s: InputStream): InputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s + } + + def dataSerialize( + blockId: String, + values: Iterator[Any], + serializer: Serializer = defaultSerializer): ByteBuffer = { + val byteStream = new FastByteArrayOutputStream(4096) + val ser = serializer.newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + byteStream.trim() + ByteBuffer.wrap(byteStream.array) + } + + /** + * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of + * the iterator is reached. 
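As a rough illustration of how this pair of helpers is meant to round-trip (an editor's sketch, not part of the patch; `blockManager` stands for an instance of the surrounding class, and the block id and values are made up), serializing and then deserializing should give back the original elements, with compression applied and undone symmetrically:

  val bytes  = blockManager.dataSerialize("rdd_0_0", Iterator(1, 2, 3))
  val values = blockManager.dataDeserialize("rdd_0_0", bytes)
  assert(values.toList == List(1, 2, 3))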
+ */ + def dataDeserialize( + blockId: String, + bytes: ByteBuffer, + serializer: Serializer = defaultSerializer): Iterator[Any] = { + bytes.rewind() + val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) + serializer.newInstance().deserializeStream(stream).asIterator + } + + def stop() { + if (heartBeatTask != null) { + heartBeatTask.cancel() + } + connectionManager.stop() + actorSystem.stop(slaveActor) + blockInfo.clear() + memoryStore.clear() + diskStore.clear() + metadataCleaner.cancel() + logInfo("BlockManager stopped") + } +} + + +private[spark] object BlockManager extends Logging { + + val ID_GENERATOR = new IdGenerator + + def getMaxMemoryFromSystemProperties: Long = { + val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble + (Runtime.getRuntime.maxMemory * memoryFraction).toLong + } + + def getHeartBeatFrequencyFromSystemProperties: Long = + + System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 + + def getDisableHeartBeatsForTesting: Boolean = + System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. + */ + def dispose(buffer: ByteBuffer) { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + logTrace("Unmapping " + buffer) + if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { + buffer.asInstanceOf[DirectBuffer].cleaner().clean() + } + } + } + + def blockIdsToBlockManagers( + blockIds: Array[String], + env: SparkEnv, + blockManagerMaster: BlockManagerMaster = null) + : Map[String, Seq[BlockManagerId]] = + { + // env == null and blockManagerMaster != null is used in tests + assert (env != null || blockManagerMaster != null) + val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) { + env.blockManager.getLocationBlockIds(blockIds) + } else { + blockManagerMaster.getLocations(blockIds) + } + + val blockManagers = new HashMap[String, Seq[BlockManagerId]] + for (i <- 0 until blockIds.length) { + blockManagers(blockIds(i)) = blockLocations(i) + } + blockManagers.toMap + } + + def blockIdsToExecutorIds( + blockIds: Array[String], + env: SparkEnv, + blockManagerMaster: BlockManagerMaster = null) + : Map[String, Seq[String]] = + { + blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) + } + + def blockIdsToHosts( + blockIds: Array[String], + env: SparkEnv, + blockManagerMaster: BlockManagerMaster = null) + : Map[String, Seq[String]] = + { + blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..74207f59af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.util.Utils
+
+/**
+ * This class represents a unique identifier for a BlockManager.
+ * The first two constructors of this class are made private to ensure that
+ * BlockManagerId objects can be created only using the apply method in
+ * the companion object. This allows de-duplication of ID objects.
+ * Also, constructor parameters are private to ensure that parameters cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+    private var executorId_ : String,
+    private var host_ : String,
+    private var port_ : Int,
+    private var nettyPort_ : Int
+  ) extends Externalizable {
+
+  private def this() = this(null, null, 0, 0)  // For deserialization only
+
+  def executorId: String = executorId_
+
+  if (null != host_) {
+    Utils.checkHost(host_, "Expected hostname")
+    assert (port_ > 0)
+  }
+
+  def hostPort: String = {
+    // DEBUG code
+    Utils.checkHost(host)
+    assert (port > 0)
+
+    host + ":" + port
+  }
+
+  def host: String = host_
+
+  def port: Int = port_
+
+  def nettyPort: Int = nettyPort_
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeUTF(executorId_)
+    out.writeUTF(host_)
+    out.writeInt(port_)
+    out.writeInt(nettyPort_)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    executorId_ = in.readUTF()
+    host_ = in.readUTF()
+    port_ = in.readInt()
+    nettyPort_ = in.readInt()
+  }
+
+  @throws(classOf[IOException])
+  private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
+
+  override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort)
+
+  override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort
+
+  override def equals(that: Any) = that match {
+    case id: BlockManagerId =>
+      executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort
+    case _ =>
+      false
+  }
+}
+
+
+private[spark] object BlockManagerId {
+
+  /**
+   * Returns a [[org.apache.spark.storage.BlockManagerId]] for the given configuration.
+   *
+   * @param execId ID of the executor.
+   * @param host Host name of the block manager.
+   * @param port Port of the block manager.
+   * @param nettyPort Optional port for the Netty-based shuffle sender.
+   * @return A new [[org.apache.spark.storage.BlockManagerId]].
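A small usage sketch for this factory (editor's illustration; the executor id, host, and ports are hypothetical). Because every instance is routed through getCachedBlockManagerId, two calls with equal fields hand back the same cached object:

  val a = BlockManagerId("executor-1", "host-a", 7077, 0)
  val b = BlockManagerId("executor-1", "host-a", 7077, 0)
  assert(a eq b)  // de-duplicated via the companion's cache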
+ */ + def apply(execId: String, host: String, port: Int, nettyPort: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) + + def apply(in: ObjectInput) = { + val obj = new BlockManagerId() + obj.readExternal(in) + getCachedBlockManagerId(obj) + } + + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + blockManagerIdCache.putIfAbsent(id, id) + blockManagerIdCache.get(id) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala new file mode 100644 index 0000000000..0c977f05d1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io._ +import java.util.{HashMap => JHashMap} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.util.Random + +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global + +import akka.pattern.ask +import scala.concurrent.duration._ + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.storage.BlockManagerMessages._ + +private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { + + val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt + + val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" + + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + def removeExecutor(execId: String) { + tell(RemoveExecutor(execId)) + logInfo("Removed " + execId + " successfully in removeExecutor") + } + + /** + * Send the driver actor a heart beat from the slave. Returns true if everything works out, + * false if the driver does not know about the given block manager, which means the block + * manager should re-register. + */ + def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { + askDriverWithReply[Boolean](HeartBeat(blockManagerId)) + } + + /** Register the BlockManager's id with the driver. 
*/ + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + logInfo("Trying to register BlockManager") + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + logInfo("Registered BlockManager") + } + + def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long): Boolean = { + val res = askDriverWithReply[Boolean]( + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) + res + } + + /** Get locations of the blockId from the driver */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + } + + /** Get locations of multiple blockIds from the driver */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + } + + /** Get ids of other nodes in the cluster from the driver */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) + } + result + } + + /** + * Remove a block from the slaves that have it. This can only be used to remove + * blocks that the driver knows about. + */ + def removeBlock(blockId: String) { + askDriverWithReply(RemoveBlock(blockId)) + } + + /** + * Remove all blocks belonging to the given RDD. + */ + def removeRdd(rddId: Int, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + future onFailure { + case e: Throwable => logError("Failed to remove RDD " + rddId, e) + } + if (blocking) { + Await.result(future, timeout) + } + } + + /** + * Return the memory status for each block manager, in the form of a map from + * the block manager's id to two long values. The first value is the maximum + * amount of memory allocated for the block manager, while the second is the + * amount of remaining memory. + */ + def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { + askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + } + + def getStorageStatus: Array[StorageStatus] = { + askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + } + + /** Stop the driver actor, called only on the Spark driver node */ + def stop() { + if (driverActor != null) { + tell(StopBlockManagerMaster) + driverActor = null + logInfo("BlockManagerMaster stopped") + } + } + + /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + private def tell(message: Any) { + if (!askDriverWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterActor returned false, expected true.") + } + } + + /** + * Send a message to the driver actor and get its result within a default timeout, or + * throw a SparkException if this fails. 
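For orientation (editor's sketch; the values shown are just the defaults quoted above), the ask-with-retry helper below is tuned entirely through system properties read when the BlockManagerMaster is constructed:

  System.setProperty("spark.akka.num.retries", "3")    // AKKA_RETRY_ATTEMPTS
  System.setProperty("spark.akka.retry.wait", "3000")  // AKKA_RETRY_INTERVAL_MS, in milliseconds
  System.setProperty("spark.akka.askTimeout", "10")    // ask timeout, in seconds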
+ */ + private def askDriverWithReply[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (driverActor == null) { + throw new SparkException("Error sending message to BlockManager as driverActor is null " + + "[message = " + message + "]") + } + var attempts = 0 + var lastException: Exception = null + while (attempts < AKKA_RETRY_ATTEMPTS) { + attempts += 1 + try { + val future = driverActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) + } + Thread.sleep(AKKA_RETRY_INTERVAL_MS) + } + + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala new file mode 100644 index 0000000000..3776951782 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable +import scala.collection.JavaConversions._ + +import akka.actor.{Actor, ActorRef, Cancellable} +import akka.pattern.ask + +import scala.concurrent.duration._ +import scala.concurrent.Future + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.Utils + + +/** + * BlockManagerMasterActor is an actor on the master node to track statuses of + * all slaves' block managers. + */ +private[spark] +class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { + + // Mapping from block manager id to the block manager's information. + private val blockManagerInfo = + new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] + + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] + + // Mapping from block id to the set of block managers that have the block. 
+ private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] + + val akkaTimeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + initLogging() + + val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", + "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong + + val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", + "60000").toLong + + var timeoutCheckingTask: Cancellable = null + + override def preStart() { + if (!BlockManager.getDisableHeartBeatsForTesting) { + import context.dispatcher + timeoutCheckingTask = context.system.scheduler.schedule( + 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) + } + super.preStart() + } + + def receive = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + register(blockManagerId, maxMemSize, slaveActor) + sender ! true + + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + // TODO: Ideally we want to handle all the message replies in receive instead of in the + // individual private methods. + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) + + case GetLocations(blockId) => + sender ! getLocations(blockId) + + case GetLocationsMultipleBlockIds(blockIds) => + sender ! getLocationsMultipleBlockIds(blockIds) + + case GetPeers(blockManagerId, size) => + sender ! getPeers(blockManagerId, size) + + case GetMemoryStatus => + sender ! memoryStatus + + case GetStorageStatus => + sender ! storageStatus + + case RemoveRdd(rddId) => + sender ! removeRdd(rddId) + + case RemoveBlock(blockId) => + removeBlockFromWorkers(blockId) + sender ! true + + case RemoveExecutor(execId) => + removeExecutor(execId) + sender ! true + + case StopBlockManagerMaster => + logInfo("Stopping BlockManagerMaster") + sender ! true + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel() + } + context.stop(self) + + case ExpireDeadHosts => + expireDeadHosts() + + case HeartBeat(blockManagerId) => + sender ! heartBeat(blockManagerId) + + case other => + logWarning("Got unknown message: " + other) + } + + private def removeRdd(rddId: Int): Future[Seq[Int]] = { + // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // from the slaves. + + val prefix = "rdd_" + rddId + "_" + // Find all blocks for the given RDD, remove the block from both blockLocations and + // the blockManagerInfo that is tracking the blocks. + val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) + blocks.foreach { blockId => + val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) + bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) + blockLocations.remove(blockId) + } + + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. + import context.dispatcher + val removeMsg = RemoveRdd(rddId) + Future.sequence(blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq) + } + + private def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId + + // Remove it from blockManagerInfo and remove all the blocks. 
+    blockManagerInfo.remove(blockManagerId)
+    val iterator = info.blocks.keySet.iterator
+    while (iterator.hasNext) {
+      val blockId = iterator.next
+      val locations = blockLocations.get(blockId)
+      locations -= blockManagerId
+      if (locations.size == 0) {
+        blockLocations.remove(blockId)
+      }
+    }
+  }
+
+  private def expireDeadHosts() {
+    logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
+    val now = System.currentTimeMillis()
+    val minSeenTime = now - slaveTimeout
+    val toRemove = new mutable.HashSet[BlockManagerId]
+    for (info <- blockManagerInfo.values) {
+      if (info.lastSeenMs < minSeenTime) {
+        logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " +
+          (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
+        toRemove += info.blockManagerId
+      }
+    }
+    toRemove.foreach(removeBlockManager)
+  }
+
+  private def removeExecutor(execId: String) {
+    logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
+    blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
+  }
+
+  private def heartBeat(blockManagerId: BlockManagerId): Boolean = {
+    if (!blockManagerInfo.contains(blockManagerId)) {
+      blockManagerId.executorId == "<driver>" && !isLocal
+    } else {
+      blockManagerInfo(blockManagerId).updateLastSeenMs()
+      true
+    }
+  }
+
+  // Remove a block from the slaves that have it. This can only be used to remove
+  // blocks that the master knows about.
+  private def removeBlockFromWorkers(blockId: String) {
+    val locations = blockLocations.get(blockId)
+    if (locations != null) {
+      locations.foreach { blockManagerId: BlockManagerId =>
+        val blockManager = blockManagerInfo.get(blockManagerId)
+        if (blockManager.isDefined) {
+          // Remove the block from the slave's BlockManager.
+          // Doesn't actually wait for a confirmation and the message might get lost.
+          // If message loss becomes frequent, we should add retry logic here.
+          blockManager.get.slaveActor ! RemoveBlock(blockId)
+        }
+      }
+    }
+  }
+
+  // Return a map from the block manager id to max memory and remaining memory.
+  private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
+    blockManagerInfo.map { case(blockManagerId, info) =>
+      (blockManagerId, (info.maxMem, info.remainingMem))
+    }.toMap
+  }
+
+  private def storageStatus: Array[StorageStatus] = {
+    blockManagerInfo.map { case(blockManagerId, info) =>
+      import collection.JavaConverters._
+      StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
+    }.toArray
+  }
+
+  private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+    if (id.executorId == "<driver>" && !isLocal) {
+      // Got a register message from the master node; don't register it
+    } else if (!blockManagerInfo.contains(id)) {
+      blockManagerIdByExecutor.get(id.executorId) match {
+        case Some(manager) =>
+          // A block manager of the same executor already exists.
+          // This should never happen. Let's just quit.
+ logError("Got two different block manager registrations on " + id.executorId) + System.exit(1) + case None => + blockManagerIdByExecutor(id.executorId) = id + } + blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveActor) + } + } + + private def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long) { + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.executorId == "<driver>" && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + sender ! true + } else { + sender ! false + } + return + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + sender ! true + return + } + + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) + + var locations: mutable.HashSet[BlockManagerId] = null + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId) + } else { + locations = new mutable.HashSet[BlockManagerId] + blockLocations.put(blockId, locations) + } + + if (storageLevel.isValid) { + locations.add(blockManagerId) + } else { + locations.remove(blockManagerId) + } + + // Remove the block from master tracking if it has been removed on all slaves. + if (locations.size == 0) { + blockLocations.remove(blockId) + } + sender ! true + } + + private def getLocations(blockId: String): Seq[BlockManagerId] = { + if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty + } + + private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + blockIds.map(blockId => getLocations(blockId)) + } + + private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { + val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + + val selfIndex = peers.indexOf(blockManagerId) + if (selfIndex == -1) { + throw new SparkException("Self index for " + blockManagerId + " not found") + } + + // Note that this logic will select the same node multiple times if there aren't enough peers + Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq + } +} + + +private[spark] +object BlockManagerMasterActor { + + case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + + class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveActor: ActorRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[String, BlockStatus] + + logInfo("Registering block manager %s with %s RAM".format( + blockManagerId.hostPort, Utils.bytesToString(maxMem))) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + diskSize: Long) { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + + if (originalLevel.useMemory) { + _remainingMem += memSize + } + } + + if (storageLevel.isValid) { + // isValid means it is either stored in-memory or on-disk. 
+ _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + if (storageLevel.useMemory) { + _remainingMem -= memSize + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), + Utils.bytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + _remainingMem += blockStatus.memSize + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), + Utils.bytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + } + } + } + + def removeBlock(blockId: String) { + if (_blocks.containsKey(blockId)) { + _remainingMem += _blocks.get(blockId).memSize + _blocks.remove(blockId) + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[String, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala new file mode 100644 index 0000000000..24333a179c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import akka.actor.ActorRef + + +private[storage] object BlockManagerMessages { + ////////////////////////////////////////////////////////////////////////////////// + // Messages from the master to slaves. + ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerSlave + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + case class RemoveBlock(blockId: String) extends ToBlockManagerSlave + + // Remove all blocks belonging to a specific RDD. + case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + + + ////////////////////////////////////////////////////////////////////////////////// + // Messages from slaves to the master. 
+ ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerMaster + + case class RegisterBlockManager( + blockManagerId: BlockManagerId, + maxMemSize: Long, + sender: ActorRef) + extends ToBlockManagerMaster + + case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + + class UpdateBlockInfo( + var blockManagerId: BlockManagerId, + var blockId: String, + var storageLevel: StorageLevel, + var memSize: Long, + var diskSize: Long) + extends ToBlockManagerMaster + with Externalizable { + + def this() = this(null, null, null, 0, 0) // For deserialization only + + override def writeExternal(out: ObjectOutput) { + blockManagerId.writeExternal(out) + out.writeUTF(blockId) + storageLevel.writeExternal(out) + out.writeLong(memSize) + out.writeLong(diskSize) + } + + override def readExternal(in: ObjectInput) { + blockManagerId = BlockManagerId(in) + blockId = in.readUTF() + storageLevel = StorageLevel(in) + memSize = in.readLong() + diskSize = in.readLong() + } + } + + object UpdateBlockInfo { + def apply(blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) + } + + // For pattern-matching + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) + } + } + + case class GetLocations(blockId: String) extends ToBlockManagerMaster + + case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster + + case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster + + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster + + case object StopBlockManagerMaster extends ToBlockManagerMaster + + case object GetMemoryStatus extends ToBlockManagerMaster + + case object ExpireDeadHosts extends ToBlockManagerMaster + + case object GetStorageStatus extends ToBlockManagerMaster +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala new file mode 100644 index 0000000000..951503019f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import akka.actor.Actor + +import org.apache.spark.storage.BlockManagerMessages._ + + +/** + * An actor to take commands from the master to execute options. For example, + * this is used to remove blocks from the slave's BlockManager. 
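For context (editor's sketch; `slaveActor`, `akkaTimeout`, and the RDD/block ids are placeholders), the master drives this actor with the messages defined above, either fire-and-forget or asking for the number of blocks removed, mirroring the call pattern used in BlockManagerMasterActor.removeRdd:

  import akka.pattern.ask

  slaveActor ! RemoveBlock("rdd_42_0")                                  // no reply expected
  val removed = slaveActor.ask(RemoveRdd(42))(akkaTimeout).mapTo[Int]   // Future[Int], answered below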
+ */ +class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { + override def receive = { + + case RemoveBlock(blockId) => + blockManager.removeBlock(blockId) + + case RemoveRdd(rddId) => + val numBlocksRemoved = blockManager.removeRdd(rddId) + sender ! numBlocksRemoved + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala new file mode 100644 index 0000000000..3d709cfde4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import com.codahale.metrics.{Gauge,MetricRegistry} + +import org.apache.spark.metrics.source.Source + + +private[spark] class BlockManagerSource(val blockManager: BlockManager) extends Source { + val metricRegistry = new MetricRegistry() + val sourceName = "BlockManager" + + metricRegistry.register(MetricRegistry.name("memory", "maxMem", "MBytes"), new Gauge[Long] { + override def getValue: Long = { + val storageStatusList = blockManager.master.getStorageStatus + val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) + maxMem / 1024 / 1024 + } + }) + + metricRegistry.register(MetricRegistry.name("memory", "remainingMem", "MBytes"), new Gauge[Long] { + override def getValue: Long = { + val storageStatusList = blockManager.master.getStorageStatus + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) + remainingMem / 1024 / 1024 + } + }) + + metricRegistry.register(MetricRegistry.name("memory", "memUsed", "MBytes"), new Gauge[Long] { + override def getValue: Long = { + val storageStatusList = blockManager.master.getStorageStatus + val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) + (maxMem - remainingMem) / 1024 / 1024 + } + }) + + metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed", "MBytes"), new Gauge[Long] { + override def getValue: Long = { + val storageStatusList = blockManager.master.getStorageStatus + val diskSpaceUsed = storageStatusList + .flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_ + _) + .getOrElse(0L) + + diskSpaceUsed / 1024 / 1024 + } + }) +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala new file mode 100644 index 0000000000..678c38203c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import org.apache.spark.{Logging} +import org.apache.spark.network._ +import org.apache.spark.util.Utils + +/** + * A network interface for BlockManager. Each slave should have one + * BlockManagerWorker. + * + * TODO: Use event model. + */ +private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { + initLogging() + + blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) + + def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { + logDebug("Handling message " + msg) + msg match { + case bufferMessage: BufferMessage => { + try { + logDebug("Handling as a buffer message " + bufferMessage) + val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) + logDebug("Parsed as a block message array") + val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) + return Some(new BlockMessageArray(responseMessages).toBufferMessage) + } catch { + case e: Exception => logError("Exception handling buffer message", e) + return None + } + } + case otherMessage: Any => { + logError("Unknown type message received: " + otherMessage) + return None + } + } + } + + def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { + blockMessage.getType match { + case BlockMessage.TYPE_PUT_BLOCK => { + val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) + logDebug("Received [" + pB + "]") + putBlock(pB.id, pB.data, pB.level) + return None + } + case BlockMessage.TYPE_GET_BLOCK => { + val gB = new GetBlock(blockMessage.getId) + logDebug("Received [" + gB + "]") + val buffer = getBlock(gB.id) + if (buffer == null) { + return None + } + return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) + } + case _ => return None + } + } + + private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { + val startTimeMs = System.currentTimeMillis() + logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) + blockManager.putBytes(id, bytes, level) + logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + + " with data size: " + bytes.limit) + } + + private def getBlock(id: String): ByteBuffer = { + val startTimeMs = System.currentTimeMillis() + logDebug("GetBlock " + id + " started from " + startTimeMs) + val buffer = blockManager.getLocalBytes(id) match { + case Some(bytes) => bytes + case None => null + } + logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + + " and got buffer " + buffer) + return buffer + } +} + +private[spark] object BlockManagerWorker extends Logging { + private var blockManagerWorker: BlockManagerWorker = null + + initLogging() + + def startBlockManagerWorker(manager: BlockManager) { + 
blockManagerWorker = new BlockManagerWorker(manager) + } + + def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { + val blockManager = blockManagerWorker.blockManager + val connectionManager = blockManager.connectionManager + val blockMessage = BlockMessage.fromPutBlock(msg) + val blockMessageArray = new BlockMessageArray(blockMessage) + val resultMessage = connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage) + return (resultMessage != None) + } + + def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { + val blockManager = blockManagerWorker.blockManager + val connectionManager = blockManager.connectionManager + val blockMessage = BlockMessage.fromGetBlock(msg) + val blockMessageArray = new BlockMessageArray(blockMessage) + val responseMessage = connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage) + responseMessage match { + case Some(message) => { + val bufferMessage = message.asInstanceOf[BufferMessage] + logDebug("Response message received " + bufferMessage) + BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { + logDebug("Found " + blockMessage) + return blockMessage.getData + }) + } + case None => logDebug("No response message received"); return null + } + return null + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala new file mode 100644 index 0000000000..d8fa6a91d1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.collection.mutable.StringBuilder +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.network._ + +private[spark] case class GetBlock(id: String) +private[spark] case class GotBlock(id: String, data: ByteBuffer) +private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) + +private[spark] class BlockMessage() { + // Un-initialized: typ = 0 + // GetBlock: typ = 1 + // GotBlock: typ = 2 + // PutBlock: typ = 3 + private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED + private var id: String = null + private var data: ByteBuffer = null + private var level: StorageLevel = null + + def set(getBlock: GetBlock) { + typ = BlockMessage.TYPE_GET_BLOCK + id = getBlock.id + } + + def set(gotBlock: GotBlock) { + typ = BlockMessage.TYPE_GOT_BLOCK + id = gotBlock.id + data = gotBlock.data + } + + def set(putBlock: PutBlock) { + typ = BlockMessage.TYPE_PUT_BLOCK + id = putBlock.id + data = putBlock.data + level = putBlock.level + } + + def set(buffer: ByteBuffer) { + val startTime = System.currentTimeMillis + /* + println() + println("BlockMessage: ") + while(buffer.remaining > 0) { + print(buffer.get()) + } + buffer.rewind() + println() + println() + */ + typ = buffer.getInt() + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + id = idBuilder.toString() + + if (typ == BlockMessage.TYPE_PUT_BLOCK) { + + val booleanInt = buffer.getInt() + val replication = buffer.getInt() + level = StorageLevel(booleanInt, replication) + + val dataLength = buffer.getInt() + data = ByteBuffer.allocate(dataLength) + if (dataLength != buffer.remaining) { + throw new Exception("Error parsing buffer") + } + data.put(buffer) + data.flip() + } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { + + val dataLength = buffer.getInt() + data = ByteBuffer.allocate(dataLength) + if (dataLength != buffer.remaining) { + throw new Exception("Error parsing buffer") + } + data.put(buffer) + data.flip() + } + + val finishTime = System.currentTimeMillis + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getType: Int = { + return typ + } + + def getId: String = { + return id + } + + def getData: ByteBuffer = { + return data + } + + def getLevel: StorageLevel = { + return level + } + + def toBufferMessage: BufferMessage = { + val startTime = System.currentTimeMillis + val buffers = new ArrayBuffer[ByteBuffer]() + var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) + buffer.putInt(typ).putInt(id.length()) + id.foreach((x: Char) => buffer.putChar(x)) + buffer.flip() + buffers += buffer + + if (typ == BlockMessage.TYPE_PUT_BLOCK) { + buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) + buffer.flip() + buffers += buffer + + buffer = ByteBuffer.allocate(4).putInt(data.remaining) + buffer.flip() + buffers += buffer + + buffers += data + } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { + buffer = ByteBuffer.allocate(4).putInt(data.remaining) + buffer.flip() + buffers += buffer + + buffers += data + } + + /* + println() + println("BlockMessage: ") + buffers.foreach(b => { + while(b.remaining > 0) { + print(b.get()) + } + b.rewind() + }) + println() + println() + */ + val finishTime = System.currentTimeMillis + return Message.createBufferMessage(buffers) + } + + override def toString: String = { + "BlockMessage [type = " + 
typ + ", id = " + id + ", level = " + level + + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + } +} + +private[spark] object BlockMessage { + val TYPE_NON_INITIALIZED: Int = 0 + val TYPE_GET_BLOCK: Int = 1 + val TYPE_GOT_BLOCK: Int = 2 + val TYPE_PUT_BLOCK: Int = 3 + + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(bufferMessage) + newBlockMessage + } + + def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(buffer) + newBlockMessage + } + + def fromGetBlock(getBlock: GetBlock): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(getBlock) + newBlockMessage + } + + def fromGotBlock(gotBlock: GotBlock): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(gotBlock) + newBlockMessage + } + + def fromPutBlock(putBlock: PutBlock): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(putBlock) + newBlockMessage + } + + def main(args: Array[String]) { + val B = new BlockMessage() + B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) + val bMsg = B.toBufferMessage + val C = new BlockMessage() + C.set(bMsg) + + println(B.getId + " " + B.getLevel) + println(C.getId + " " + C.getLevel) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala new file mode 100644 index 0000000000..0aaf846b5b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.network._ + +private[spark] +class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { + + def this(bm: BlockMessage) = this(Array(bm)) + + def this() = this(null.asInstanceOf[Seq[BlockMessage]]) + + def apply(i: Int) = blockMessages(i) + + def iterator = blockMessages.iterator + + def length = blockMessages.length + + initLogging() + + def set(bufferMessage: BufferMessage) { + val startTime = System.currentTimeMillis + val newBlockMessages = new ArrayBuffer[BlockMessage]() + val buffer = bufferMessage.buffers(0) + buffer.clear() + /* + println() + println("BlockMessageArray: ") + while(buffer.remaining > 0) { + print(buffer.get()) + } + buffer.rewind() + println() + println() + */ + while (buffer.remaining() > 0) { + val size = buffer.getInt() + logDebug("Creating block message of size " + size + " bytes") + val newBuffer = buffer.slice() + newBuffer.clear() + newBuffer.limit(size) + logDebug("Trying to convert buffer " + newBuffer + " to block message") + val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) + logDebug("Created " + newBlockMessage) + newBlockMessages += newBlockMessage + buffer.position(buffer.position() + size) + } + val finishTime = System.currentTimeMillis + logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") + this.blockMessages = newBlockMessages + } + + def toBufferMessage: BufferMessage = { + val buffers = new ArrayBuffer[ByteBuffer]() + + blockMessages.foreach(blockMessage => { + val bufferMessage = blockMessage.toBufferMessage + logDebug("Adding " + blockMessage) + val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) + sizeBuffer.flip + buffers += sizeBuffer + buffers ++= bufferMessage.buffers + logDebug("Added " + bufferMessage) + }) + + logDebug("Buffer list:") + buffers.foreach((x: ByteBuffer) => logDebug("" + x)) + /* + println() + println("BlockMessageArray: ") + buffers.foreach(b => { + while(b.remaining > 0) { + print(b.get()) + } + b.rewind() + }) + println() + println() + */ + return Message.createBufferMessage(buffers) + } +} + +private[spark] object BlockMessageArray { + + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { + val newBlockMessageArray = new BlockMessageArray() + newBlockMessageArray.set(bufferMessage) + newBlockMessageArray + } + + def main(args: Array[String]) { + val blockMessages = + (0 until 10).map { i => + if (i % 2 == 0) { + val buffer = ByteBuffer.allocate(100) + buffer.clear + BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) + } else { + BlockMessage.fromGetBlock(GetBlock(i.toString)) + } + } + val blockMessageArray = new BlockMessageArray(blockMessages) + println("Block message array created") + + val bufferMessage = blockMessageArray.toBufferMessage + println("Converted to buffer message") + + val totalSize = bufferMessage.size + val newBuffer = ByteBuffer.allocate(totalSize) + newBuffer.clear() + bufferMessage.buffers.foreach(buffer => { + assert (0 == buffer.position()) + newBuffer.put(buffer) + buffer.rewind() + }) + newBuffer.flip + val newBufferMessage = Message.createBufferMessage(newBuffer) + println("Copied to new buffer message, size = " + newBufferMessage.size) + + val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) + println("Converted 
back to block message array") + newBlockMessageArray.foreach(blockMessage => { + blockMessage.getType match { + case BlockMessage.TYPE_PUT_BLOCK => { + val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) + println(pB) + } + case BlockMessage.TYPE_GET_BLOCK => { + val gB = new GetBlock(blockMessage.getId) + println(gB) + } + } + }) + } +} + + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala new file mode 100644 index 0000000000..39f103297f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + + +/** + * An interface for writing JVM objects to some underlying storage. This interface allows + * appending data to an existing block, and can guarantee atomicity in the case of faults + * as it allows the caller to revert partial writes. + * + * This interface does not support concurrent writes. + */ +abstract class BlockObjectWriter(val blockId: String) { + + var closeEventHandler: () => Unit = _ + + def open(): BlockObjectWriter + + def close() { + closeEventHandler() + } + + def isOpen: Boolean + + def registerCloseEventHandler(handler: () => Unit) { + closeEventHandler = handler + } + + /** + * Flush the partial writes and commit them as a single atomic block. Return the + * number of bytes written for this commit. + */ + def commit(): Long + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. + */ + def revertPartialWrites() + + /** + * Writes an object. + */ + def write(value: Any) + + /** + * Size of the valid writes, in bytes. + */ + def size(): Long +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala new file mode 100644 index 0000000000..fa834371f4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
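The abstract contract above is easiest to read as a write-then-commit-or-revert loop. A sketch of the intended calling pattern, inferred from the doc comments rather than taken from the patch (it assumes a close handler has already been registered via registerCloseEventHandler):

    def writeBlock(writer: BlockObjectWriter, records: Iterator[Any]): Long = {
      writer.open()
      try {
        records.foreach(writer.write)
        writer.commit()                  // flush and atomically expose the writes; returns bytes written
      } catch {
        case e: Exception =>
          writer.revertPartialWrites()   // roll back anything not yet committed
          throw e
      } finally {
        writer.close()
      }
    }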
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.Logging + +/** + * Abstract class to store blocks + */ +private[spark] +abstract class BlockStore(val blockManager: BlockManager) extends Logging { + def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) + + /** + * Put in a block and, possibly, also return its content as either bytes or another Iterator. + * This is used to efficiently write the values to multiple locations (e.g. for replication). + * + * @return a PutResult that contains the size of the data, as well as the values put if + * returnValues is true (if not, the result's data field can be null) + */ + def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + returnValues: Boolean) : PutResult + + /** + * Return the size of a block in bytes. + */ + def getSize(blockId: String): Long + + def getBytes(blockId: String): Option[ByteBuffer] + + def getValues(blockId: String): Option[Iterator[Any]] + + /** + * Remove a block, if it exists. + * @param blockId the block to remove. + * @return True if the block was found and removed, False otherwise. + */ + def remove(blockId: String): Boolean + + def contains(blockId: String): Boolean + + def clear() { } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala new file mode 100644 index 0000000000..fc25ef0fae --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
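The putValues/PutResult pairing lets a caller replicate a block without reading it back from the store; the result's data field is null when returnValues is false. A hedged caller-side sketch (store, values and replicateToPeer are stand-ins, not names from the patch):

    val result = store.putValues("rdd_1_0", values, StorageLevel.MEMORY_ONLY_SER, returnValues = true)
    result.data match {
      case Right(bytes) => replicateToPeer(bytes)   // already-serialized copy of the block
      case Left(iter)   => ()                       // values came back as an iterator; serialize before shipping
    }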
+ */ + +package org.apache.spark.storage + +import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.channels.FileChannel.MapMode +import java.util.{Random, Date} +import java.text.SimpleDateFormat + +import scala.collection.mutable.ArrayBuffer + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import org.apache.spark.executor.ExecutorExitCode +import org.apache.spark.serializer.{Serializer, SerializationStream} +import org.apache.spark.Logging +import org.apache.spark.network.netty.ShuffleSender +import org.apache.spark.network.netty.PathResolver +import org.apache.spark.util.Utils + + +/** + * Stores BlockManager blocks on disk. + */ +private class DiskStore(blockManager: BlockManager, rootDirs: String) + extends BlockStore(blockManager) with Logging { + + class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) + extends BlockObjectWriter(blockId) { + + private val f: File = createFile(blockId /*, allowAppendExisting */) + + // The file channel, used for repositioning / truncating the file. + private var channel: FileChannel = null + private var bs: OutputStream = null + private var objOut: SerializationStream = null + private var lastValidPosition = 0L + private var initialized = false + + override def open(): DiskBlockObjectWriter = { + val fos = new FileOutputStream(f, true) + channel = fos.getChannel() + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) + objOut = serializer.newInstance().serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + objOut.close() + channel = null + bs = null + objOut = null + } + // Invoke the close callback handler. + super.close() + } + + override def isOpen: Boolean = objOut != null + + // Flush the partial writes, and set valid length to be the length of the entire file. + // Return the number of bytes written for this commit. + override def commit(): Long = { + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } + } + + override def revertPartialWrites() { + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } + } + + override def write(value: Any) { + if (!initialized) { + open() + } + objOut.writeObject(value) + } + + override def size(): Long = lastValidPosition + } + + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + + private var shuffleSender : ShuffleSender = null + // Create one local directory for each path mentioned in spark.local.dir; then, inside this + // directory, create multiple subdirectories that we will hash files into, in order to avoid + // having really large inodes at the top level. 
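Concretely, the two-level layout works like this; the block id and resulting path below are only an illustration, mirroring getFile() defined later in this file:

    val blockId = "shuffle_0_5_3"                                   // example id
    val hash = math.abs(blockId.hashCode)
    val dirId = hash % localDirs.length                             // which spark.local.dir root
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir   // which "%02x" subdirectory
    // e.g. <rootDir>/spark-local-20130901123045-1f2e/3a/shuffle_0_5_3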
+ private val localDirs: Array[File] = createLocalDirs() + private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + + addShutdownHook() + + def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + : BlockObjectWriter = { + new DiskBlockObjectWriter(blockId, serializer, bufferSize) + } + + override def getSize(blockId: String): Long = { + getFile(blockId).length() + } + + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // So that we do not modify the input offsets ! + // duplicate does not copy buffer, so inexpensive + val bytes = _bytes.duplicate() + logDebug("Attempting to put block " + blockId) + val startTime = System.currentTimeMillis + val file = createFile(blockId) + val channel = new RandomAccessFile(file, "rw").getChannel() + while (bytes.remaining > 0) { + channel.write(bytes) + } + channel.close() + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file on disk in %d ms".format( + blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime))) + } + + private def getFileBytes(file: File): ByteBuffer = { + val length = file.length() + val channel = new RandomAccessFile(file, "r").getChannel() + val buffer = try { + channel.map(MapMode.READ_ONLY, 0, length) + } finally { + channel.close() + } + + buffer + } + + override def putValues( + blockId: String, + values: ArrayBuffer[Any], + level: StorageLevel, + returnValues: Boolean) + : PutResult = { + + logDebug("Attempting to write values for block " + blockId) + val startTime = System.currentTimeMillis + val file = createFile(blockId) + val fileOut = blockManager.wrapForCompression(blockId, + new FastBufferedOutputStream(new FileOutputStream(file))) + val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) + objOut.writeAll(values.iterator) + objOut.close() + val length = file.length() + + val timeTaken = System.currentTimeMillis - startTime + logDebug("Block %s stored as %s file on disk in %d ms".format( + blockId, Utils.bytesToString(length), timeTaken)) + + if (returnValues) { + // Return a byte buffer for the contents of the file + val buffer = getFileBytes(file) + PutResult(length, Right(buffer)) + } else { + PutResult(length, null) + } + } + + override def getBytes(blockId: String): Option[ByteBuffer] = { + val file = getFile(blockId) + val bytes = getFileBytes(file) + Some(bytes) + } + + override def getValues(blockId: String): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) + } + + /** + * A version of getValues that allows a custom serializer. This is used as part of the + * shuffle short-circuit code. + */ + def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) + } + + override def remove(blockId: String): Boolean = { + val file = getFile(blockId) + if (file.exists()) { + file.delete() + } else { + false + } + } + + override def contains(blockId: String): Boolean = { + getFile(blockId).exists() + } + + private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { + val file = getFile(blockId) + if (!allowAppendExisting && file.exists()) { + // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task + // was rescheduled on the same machine as the old task. + logWarning("File for block " + blockId + " already exists on disk: " + file + ". 
Deleting") + file.delete() + } + file + } + + private def getFile(blockId: String): File = { + logDebug("Getting file for block " + blockId) + + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + + // Create the subdirectory if it doesn't already exist + var subDir = subDirs(dirId)(subDirId) + if (subDir == null) { + subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + newDir.mkdir() + subDirs(dirId)(subDirId) = newDir + newDir + } + } + } + + new File(subDir, blockId) + } + + private def createLocalDirs(): Array[File] = { + logDebug("Creating local directories at root dirs '" + rootDirs + "'") + val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + rootDirs.split(",").map { rootDir => + var foundLocalDir = false + var localDir: File = null + var localDirId: String = null + var tries = 0 + val rand = new Random() + while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + tries += 1 + try { + localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) + localDir = new File(rootDir, "spark-local-" + localDirId) + if (!localDir.exists) { + foundLocalDir = localDir.mkdirs() + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) + } + } + if (!foundLocalDir) { + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + + " attempts to create local dir in " + rootDir) + System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) + } + logInfo("Created local directory at " + localDir) + localDir + } + } + + private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + override def run() { + logDebug("Shutdown hook called") + localDirs.foreach { localDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) + } + } + if (shuffleSender != null) { + shuffleSender.stop + } + } + }) + } + + private[storage] def startShuffleBlockSender(port: Int): Int = { + val pResolver = new PathResolver { + override def getAbsolutePath(blockId: String): String = { + if (!blockId.startsWith("shuffle_")) { + return null + } + DiskStore.this.getFile(blockId).getAbsolutePath() + } + } + shuffleSender = new ShuffleSender(port, pResolver) + logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) + shuffleSender.port + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala new file mode 100644 index 0000000000..3b3b2342fa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
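For reference, the PathResolver handed to the ShuffleSender above only serves shuffle files; any other block id resolves to null. An illustration (paths invented):

    // getAbsolutePath("shuffle_2_0_1")  => "/tmp/spark-local-.../0d/shuffle_2_0_1"
    // getAbsolutePath("rdd_5_3")        => null   (not a shuffle block, so not served over netty)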
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.LinkedHashMap +import java.util.concurrent.ArrayBlockingQueue +import java.nio.ByteBuffer +import collection.mutable.ArrayBuffer +import org.apache.spark.util.{SizeEstimator, Utils} + +/** + * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as + * serialized ByteBuffers. + */ +private class MemoryStore(blockManager: BlockManager, maxMemory: Long) + extends BlockStore(blockManager) { + + case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) + + private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) + private var currentMemory = 0L + // Object used to ensure that only one thread is putting blocks and if necessary, dropping + // blocks from the memory store. + private val putLock = new Object() + + logInfo("MemoryStore started with capacity %s.".format(Utils.bytesToString(maxMemory))) + + def freeMemory: Long = maxMemory - currentMemory + + override def getSize(blockId: String): Long = { + entries.synchronized { + entries.get(blockId).size + } + } + + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // Work on a duplicate - since the original input might be used elsewhere. + val bytes = _bytes.duplicate() + bytes.rewind() + if (level.deserialized) { + val values = blockManager.dataDeserialize(blockId, bytes) + val elements = new ArrayBuffer[Any] + elements ++= values + val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) + tryToPut(blockId, elements, sizeEstimate, true) + } else { + tryToPut(blockId, bytes, bytes.limit, false) + } + } + + override def putValues( + blockId: String, + values: ArrayBuffer[Any], + level: StorageLevel, + returnValues: Boolean) + : PutResult = { + + if (level.deserialized) { + val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) + tryToPut(blockId, values, sizeEstimate, true) + PutResult(sizeEstimate, Left(values.iterator)) + } else { + val bytes = blockManager.dataSerialize(blockId, values.iterator) + tryToPut(blockId, bytes, bytes.limit, false) + PutResult(bytes.limit(), Right(bytes.duplicate())) + } + } + + override def getBytes(blockId: String): Option[ByteBuffer] = { + val entry = entries.synchronized { + entries.get(blockId) + } + if (entry == null) { + None + } else if (entry.deserialized) { + Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) + } else { + Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data + } + } + + override def getValues(blockId: String): Option[Iterator[Any]] = { + val entry = entries.synchronized { + entries.get(blockId) + } + if (entry == null) { + None + } else if (entry.deserialized) { + Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) + } else { + val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data + Some(blockManager.dataDeserialize(blockId, buffer)) + } + } + + override def remove(blockId: String): Boolean = { + entries.synchronized { + val entry = 
entries.get(blockId) + if (entry != null) { + entries.remove(blockId) + currentMemory -= entry.size + logInfo("Block %s of size %d dropped from memory (free %d)".format( + blockId, entry.size, freeMemory)) + true + } else { + false + } + } + } + + override def clear() { + entries.synchronized { + entries.clear() + } + logInfo("MemoryStore cleared") + } + + /** + * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. + */ + private def getRddId(blockId: String): String = { + if (blockId.startsWith("rdd_")) { + blockId.split('_')(1) + } else { + null + } + } + + /** + * Try to put in a set of values, if we can free up enough space. The value should either be + * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) + * size must also be passed by the caller. + * + * Locks on the object putLock to ensure that all the put requests and its associated block + * dropping is done by only on thread at a time. Otherwise while one thread is dropping + * blocks to free memory for one block, another thread may use up the freed space for + * another block. + */ + private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { + // TODO: Its possible to optimize the locking by locking entries only when selecting blocks + // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been + // released, it must be ensured that those to-be-dropped blocks are not double counted for + // freeing up more space for another block that needs to be put. Only then the actually dropping + // of blocks (and writing to disk if necessary) can proceed in parallel. + putLock.synchronized { + if (ensureFreeSpace(blockId, size)) { + val entry = new Entry(value, size, deserialized) + entries.synchronized { entries.put(blockId, entry) } + currentMemory += size + if (deserialized) { + logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( + blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) + } else { + logInfo("Block %s stored as bytes to memory (size %s, free %s)".format( + blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) + } + true + } else { + // Tell the block manager that we couldn't put it in memory so that it can drop it to + // disk if the block allows disk storage. + val data = if (deserialized) { + Left(value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + false + } + } + } + + /** + * Tries to free up a given amount of space to store a particular block, but can fail and return + * false if either the block is bigger than our memory or it would require replacing another + * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that + * don't fit into memory that we want to avoid). + * + * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. + * Otherwise, the freed space may fill up before the caller puts in their new value. 
+ */ + private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { + + logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( + space, currentMemory, maxMemory)) + + if (space > maxMemory) { + logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit") + return false + } + + if (maxMemory - currentMemory < space) { + val rddToAdd = getRddId(blockIdToAdd) + val selectedBlocks = new ArrayBuffer[String]() + var selectedMemory = 0L + + // This is synchronized to ensure that the set of entries is not changed + // (because of getValue or getBytes) while traversing the iterator, as that + // can lead to exceptions. + entries.synchronized { + val iterator = entries.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + + "block from the same RDD") + return false + } + selectedBlocks += blockId + selectedMemory += pair.getValue.size + } + } + + if (maxMemory - (currentMemory - selectedMemory) >= space) { + logInfo(selectedBlocks.size + " blocks selected for dropping") + for (blockId <- selectedBlocks) { + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + } + } + return true + } else { + return false + } + } + return true + } + + override def contains(blockId: String): Boolean = { + entries.synchronized { entries.containsKey(blockId) } + } +} + diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala new file mode 100644 index 0000000000..2eba2f06b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/PutResult.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +/** + * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the + * values put if the caller asked for them to be returned (e.g. 
for chaining replication) + */ +private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer]) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala new file mode 100644 index 0000000000..9da11efb57 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.serializer.Serializer + + +private[spark] +class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) + + +private[spark] +trait ShuffleBlocks { + def acquireWriters(mapId: Int): ShuffleWriterGroup + def releaseWriters(group: ShuffleWriterGroup) +} + + +private[spark] +class ShuffleBlockManager(blockManager: BlockManager) { + + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { + new ShuffleBlocks { + // Get a group of writers for a map task. + override def acquireWriters(mapId: Int): ShuffleWriterGroup = { + val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) + } + new ShuffleWriterGroup(mapId, writers) + } + + override def releaseWriters(group: ShuffleWriterGroup) = { + // Nothing really to release here. + } + } + } +} + + +private[spark] +object ShuffleBlockManager { + + // Returns the block id for a given shuffle block. + def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { + "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId + } + + // Returns true if the block is a shuffle block. + def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") +} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala new file mode 100644 index 0000000000..755f1a760e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
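Note that in blockId the bucket id ends up last in the string even though it is the second parameter. A short illustration, plus the acquire/release pattern a map task would follow (the serializer and counts are made up):

    ShuffleBlockManager.blockId(1, 2, 0)             // => "shuffle_1_0_2"
    ShuffleBlockManager.isShuffle("shuffle_1_0_2")   // => true

    val shuffleBlocks = shuffleBlockManager.forShuffle(shuffleId = 1, numBuckets = 4, serializer)
    val group = shuffleBlocks.acquireWriters(mapId = 0)   // one BlockObjectWriter per reduce bucket
    // group.writers(bucket).write(record) for each record, then commit/close each writer
    shuffleBlocks.releaseWriters(group)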
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} + +/** + * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, + * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory + * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. + * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants for + * commonly useful storage levels. To create your own storage level object, use the factor method + * of the singleton object (`StorageLevel(...)`). + */ +class StorageLevel private( + private var useDisk_ : Boolean, + private var useMemory_ : Boolean, + private var deserialized_ : Boolean, + private var replication_ : Int = 1) + extends Externalizable { + + // TODO: Also add fields for caching priority, dataset ID, and flushing. + private def this(flags: Int, replication: Int) { + this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) + } + + def this() = this(false, true, false) // For deserialization + + def useDisk = useDisk_ + def useMemory = useMemory_ + def deserialized = deserialized_ + def replication = replication_ + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + + override def clone(): StorageLevel = new StorageLevel( + this.useDisk, this.useMemory, this.deserialized, this.replication) + + override def equals(other: Any): Boolean = other match { + case s: StorageLevel => + s.useDisk == useDisk && + s.useMemory == useMemory && + s.deserialized == deserialized && + s.replication == replication + case _ => + false + } + + def isValid = ((useMemory || useDisk) && (replication > 0)) + + def toInt: Int = { + var ret = 0 + if (useDisk_) { + ret |= 4 + } + if (useMemory_) { + ret |= 2 + } + if (deserialized_) { + ret |= 1 + } + return ret + } + + override def writeExternal(out: ObjectOutput) { + out.writeByte(toInt) + out.writeByte(replication_) + } + + override def readExternal(in: ObjectInput) { + val flags = in.readByte() + useDisk_ = (flags & 4) != 0 + useMemory_ = (flags & 2) != 0 + deserialized_ = (flags & 1) != 0 + replication_ = in.readByte() + } + + @throws(classOf[IOException]) + private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) + + override def toString: String = + "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + override def hashCode(): Int = toInt * 41 + replication + def description : String = { + var result = "" + result += (if (useDisk) "Disk " else "") + result += (if (useMemory) "Memory " else "") + result += (if (deserialized) "Deserialized " else "Serialized") + result += "%sx Replicated".format(replication) + result + } +} + + +object StorageLevel { + val NONE = new StorageLevel(false, false, false) + val DISK_ONLY = new StorageLevel(true, false, false) + val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) + val MEMORY_ONLY = new StorageLevel(false, true, true) + val MEMORY_ONLY_2 = new StorageLevel(false, 
true, true, 2) + val MEMORY_ONLY_SER = new StorageLevel(false, true, false) + val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2) + val MEMORY_AND_DISK = new StorageLevel(true, true, true) + val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) + val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) + val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + /** Create a new StorageLevel object */ + def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + + /** Create a new StorageLevel object from its integer representation */ + def apply(flags: Int, replication: Int) = + getCachedStorageLevel(new StorageLevel(flags, replication)) + + /** Read StorageLevel object from ObjectInput stream */ + def apply(in: ObjectInput) = { + val obj = new StorageLevel() + obj.readExternal(in) + getCachedStorageLevel(obj) + } + + private[spark] + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + storageLevelCache.putIfAbsent(level, level) + storageLevelCache.get(level) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala new file mode 100644 index 0000000000..2bb7715696 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.{SparkContext} +import BlockManagerMasterActor.BlockStatus +import org.apache.spark.util.Utils + +private[spark] +case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, + blocks: Map[String, BlockStatus]) { + + def memUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). + reduceOption(_+_).getOrElse(0l) + } + + def diskUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). 
+ reduceOption(_+_).getOrElse(0l) + } + + def memRemaining : Long = maxMem - memUsed() + +} + +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) + extends Ordered[RDDInfo] { + override def toString = { + import Utils.bytesToString + "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, + storageLevel.toString, numCachedPartitions, numPartitions, bytesToString(memSize), bytesToString(diskSize)) + } + + override def compare(that: RDDInfo) = { + this.id - that.id + } +} + +/* Helper methods for storage-related objects */ +private[spark] +object StorageUtils { + + /* Returns RDD-level information, compiled from a list of StorageStatus objects */ + def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], + sc: SparkContext) : Array[RDDInfo] = { + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + } + + /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */ + def blockLocationsFromStorageStatus(storageStatusList: Seq[StorageStatus]) = { + val blockLocationPairs = storageStatusList + .flatMap(s => s.blocks.map(b => (b._1, s.blockManagerId.hostPort))) + blockLocationPairs.groupBy(_._1).map{case (k, v) => (k, v.unzip._2)}.toMap + } + + /* Given a list of BlockStatus objets, returns information for each RDD */ + def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + sc: SparkContext) : Array[RDDInfo] = { + + // Group by rddId, ignore the partition name + val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => + k.substring(0,k.lastIndexOf('_')) + }.mapValues(_.values.toArray) + + // For each RDD, generate an RDDInfo object + val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => + // Add up memory and disk sizes + val memSize = rddBlocks.map(_.memSize).reduce(_ + _) + val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) + + // Find the id of the RDD, e.g. rdd_1 => 1 + val rddId = rddKey.split("_").last.toInt + + // Get the friendly name and storage level for the RDD, if available + sc.persistentRdds.get(rddId).map { r => + val rddName = Option(r.name).getOrElse(rddKey) + val rddStorageLevel = r.getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) + } + }.flatten.toArray + + scala.util.Sorting.quickSort(rddInfos) + + rddInfos + } + + /* Removes all BlockStatus object that are not part of a block prefix */ + def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], + prefix: String) : Array[StorageStatus] = { + + storageStatusList.map { status => + val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) + StorageStatus(status.blockManagerId, status.maxMem, newBlocks) + } + + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala new file mode 100644 index 0000000000..f2ae8dd97d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
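The key manipulation used in rddInfoFromBlockStatusList is easy to check in the REPL; the sample block keys below are made up:

    val keys = Seq("rdd_2_0", "rdd_2_1", "rdd_5_0")
    keys.groupBy(k => k.substring(0, k.lastIndexOf('_')))
    // => Map("rdd_2" -> Seq("rdd_2_0", "rdd_2_1"), "rdd_5" -> Seq("rdd_5_0"))
    "rdd_2".split("_").last.toInt   // => 2, the RDD id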
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import akka.actor._ + +import java.util.concurrent.ArrayBlockingQueue +import util.Random +import org.apache.spark.serializer.KryoSerializer + +/** + * This class tests the BlockManager and MemoryStore for thread safety and + * deadlocks. It spawns a number of producer and consumer threads. Producer + * threads continuously pushes blocks into the BlockManager and consumer + * threads continuously retrieves the blocks form the BlockManager and tests + * whether the block is correct or not. + */ +private[spark] object ThreadingTest { + + val numProducers = 5 + val numBlocksPerProducer = 20000 + + private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { + val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) + + override def run() { + for (i <- 1 to numBlocksPerProducer) { + val blockId = "b-" + id + "-" + i + val blockSize = Random.nextInt(1000) + val block = (1 to blockSize).map(_ => Random.nextInt()) + val level = randomLevel() + val startTime = System.currentTimeMillis() + manager.put(blockId, block.iterator, level, true) + println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") + queue.add((blockId, block)) + } + println("Producer thread " + id + " terminated") + } + + def randomLevel(): StorageLevel = { + math.abs(Random.nextInt()) % 4 match { + case 0 => StorageLevel.MEMORY_ONLY + case 1 => StorageLevel.MEMORY_ONLY_SER + case 2 => StorageLevel.MEMORY_AND_DISK + case 3 => StorageLevel.MEMORY_AND_DISK_SER + } + } + } + + private[spark] class ConsumerThread( + manager: BlockManager, + queue: ArrayBlockingQueue[(String, Seq[Int])] + ) extends Thread { + var numBlockConsumed = 0 + + override def run() { + println("Consumer thread started") + while(numBlockConsumed < numBlocksPerProducer) { + val (blockId, block) = queue.take() + val startTime = System.currentTimeMillis() + manager.get(blockId) match { + case Some(retrievedBlock) => + assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, + "Block " + blockId + " did not match") + println("Got block " + blockId + " in " + + (System.currentTimeMillis - startTime) + " ms") + case None => + assert(false, "Block " + blockId + " could not be retrieved") + } + numBlockConsumed += 1 + } + println("Consumer thread terminated") + } + } + + def main(args: Array[String]) { + System.setProperty("spark.kryoserializer.buffer.mb", "1") + val actorSystem = ActorSystem("test") + val serializer = new KryoSerializer + val blockManagerMaster = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + val blockManager = new BlockManager( + "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024) + val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) + val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) + producers.foreach(_.start) + consumers.foreach(_.start) 
+ producers.foreach(_.join) + consumers.foreach(_.join) + blockManager.stop() + blockManagerMaster.stop() + actorSystem.shutdown() + actorSystem.awaitTermination() + println("Everything stopped.") + println( + "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala new file mode 100644 index 0000000000..7211dbc7c6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.servlet.http.{HttpServletResponse, HttpServletRequest} + +import scala.annotation.tailrec +import scala.util.{Try, Success, Failure} +import scala.xml.Node + +import net.liftweb.json.{JValue, pretty, render} + +import org.eclipse.jetty.server.{Server, Request, Handler} +import org.eclipse.jetty.server.handler.{ResourceHandler, HandlerList, ContextHandler, AbstractHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool + +import org.apache.spark.Logging + + +/** Utilities for launching a web server using Jetty's HTTP Server class */ +private[spark] object JettyUtils extends Logging { + // Base type for a function that returns something based on an HTTP request. Allows for + // implicit conversion from many types of functions to jetty Handlers. 
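In practice this means a UI page can be registered as a plain function, with the conversions below picking the response type from the function's return type. A hedged sketch (paths and bodies invented; assumes import org.apache.spark.ui.JettyUtils._ and javax.servlet.http.HttpServletRequest are in scope):

    val uiHandlers = Seq[(String, Handler)](
      ("/ping", (request: HttpServletRequest) => "pong"),           // via textResponderToHandler
      ("/hello", (request: HttpServletRequest) => <h1>hello</h1>)   // via htmlResponderToHandler
    )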
+ type Responder[T] = HttpServletRequest => T + + // Conversions from various types of Responder's to jetty Handlers + implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = + createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) + + implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = + createHandler(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString) + + implicit def textResponderToHandler(responder: Responder[String]): Handler = + createHandler(responder, "text/plain") + + def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, + extractFn: T => String = (in: Any) => in.toString): Handler = { + new AbstractHandler { + def handle(target: String, + baseRequest: Request, + request: HttpServletRequest, + response: HttpServletResponse) { + response.setContentType("%s;charset=utf-8".format(contentType)) + response.setStatus(HttpServletResponse.SC_OK) + baseRequest.setHandled(true) + val result = responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter().println(extractFn(result)) + } + } + } + + /** Creates a handler that always redirects the user to a given path */ + def createRedirectHandler(newPath: String): Handler = { + new AbstractHandler { + def handle(target: String, + baseRequest: Request, + request: HttpServletRequest, + response: HttpServletResponse) { + response.setStatus(302) + response.setHeader("Location", baseRequest.getRootURL + newPath) + baseRequest.setHandled(true) + } + } + } + + /** Creates a handler for serving files from a static directory */ + def createStaticHandler(resourceBase: String): ResourceHandler = { + val staticHandler = new ResourceHandler + Option(getClass.getClassLoader.getResource(resourceBase)) match { + case Some(res) => + staticHandler.setResourceBase(res.toString) + case None => + throw new Exception("Could not find resource path for Web UI: " + resourceBase) + } + staticHandler + } + + /** + * Attempts to start a Jetty server at the supplied ip:port which uses the supplied handlers. + * + * If the desired port number is contented, continues incrementing ports until a free port is + * found. Returns the chosen port and the jetty Server object. + */ + def startJettyServer(ip: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) = { + val handlersToRegister = handlers.map { case(path, handler) => + val contextHandler = new ContextHandler(path) + contextHandler.setHandler(handler) + contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] + } + + val handlerList = new HandlerList + handlerList.setHandlers(handlersToRegister.toArray) + + @tailrec + def connect(currentPort: Int): (Server, Int) = { + val server = new Server(currentPort) + val pool = new QueuedThreadPool + pool.setDaemon(true) + server.setThreadPool(pool) + server.setHandler(handlerList) + + Try { server.start() } match { + case s: Success[_] => + (server, server.getConnectors.head.getLocalPort) + case f: Failure[_] => + server.stop() + logInfo("Failed to create UI at port, %s. 
Trying again.".format(currentPort)) + logInfo("Error was: " + f.toString) + connect((currentPort + 1) % 65536) + } + } + + connect(port) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/Page.scala b/core/src/main/scala/org/apache/spark/ui/Page.scala new file mode 100644 index 0000000000..b2a069a375 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/Page.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +private[spark] object Page extends Enumeration { + val Stages, Storage, Environment, Executors = Value +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala new file mode 100644 index 0000000000..48eb096063 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
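The retry loop in startJettyServer means the requested port is only a starting point: if it is taken, the next port is tried, and the (currentPort + 1) % 65536 wrap keeps the search inside the valid port range. Illustrative only:

    // startJettyServer("0.0.0.0", 3030, handlers)
    //   3030 in use -> try 3031 -> 3031 in use -> try 3032 -> bound; returns (server, 3032)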
+ */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.{Handler, Server} + +import org.apache.spark.{Logging, SparkContext, SparkEnv} +import org.apache.spark.ui.env.EnvironmentUI +import org.apache.spark.ui.exec.ExecutorsUI +import org.apache.spark.ui.storage.BlockManagerUI +import org.apache.spark.ui.jobs.JobProgressUI +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.util.Utils + +/** Top level user interface for Spark */ +private[spark] class SparkUI(sc: SparkContext) extends Logging { + val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName()) + val port = Option(System.getProperty("spark.ui.port")).getOrElse(SparkUI.DEFAULT_PORT).toInt + var boundPort: Option[Int] = None + var server: Option[Server] = None + + val handlers = Seq[(String, Handler)]( + ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), + ("/", createRedirectHandler("/stages")) + ) + val storage = new BlockManagerUI(sc) + val jobs = new JobProgressUI(sc) + val env = new EnvironmentUI(sc) + val exec = new ExecutorsUI(sc) + + // Add MetricsServlet handlers by default + val metricsServletHandlers = SparkEnv.get.metricsSystem.getServletHandlers + + val allHandlers = storage.getHandlers ++ jobs.getHandlers ++ env.getHandlers ++ + exec.getHandlers ++ metricsServletHandlers ++ handlers + + /** Bind the HTTP server which backs this web interface */ + def bind() { + try { + val (srv, usedPort) = JettyUtils.startJettyServer("0.0.0.0", port, allHandlers) + logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) + server = Some(srv) + boundPort = Some(usedPort) + } catch { + case e: Exception => + logError("Failed to create Spark JettyUtils", e) + System.exit(1) + } + } + + /** Initialize all components of the server */ + def start() { + // NOTE: This is decoupled from bind() because of the following dependency cycle: + // DAGScheduler() requires that the port of this server is known + // This server must register all handlers, including JobProgressUI, before binding + // JobProgressUI registers a listener with SparkContext, which requires sc to initialize + jobs.start() + exec.start() + } + + def stop() { + server.foreach(_.stop()) + } + + private[spark] def appUIAddress = host + ":" + boundPort.getOrElse("-1") + +} + +private[spark] object SparkUI { + val DEFAULT_PORT = "3030" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala new file mode 100644 index 0000000000..5573b3847b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui + +import scala.xml.Node + +import org.apache.spark.SparkContext + +/** Utility functions for generating XML pages with spark content. */ +private[spark] object UIUtils { + import Page._ + + // Yarn has to go through a proxy so the base uri is provided and has to be on all links + private[spark] val uiRoot : String = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")). + getOrElse("") + + def prependBaseUri(resource: String = "") = uiRoot + resource + + /** Returns a spark page with correctly formatted headers */ + def headerSparkPage(content: => Seq[Node], sc: SparkContext, title: String, page: Page.Value) + : Seq[Node] = { + val jobs = page match { + case Stages => <li class="active"><a href={prependBaseUri("/stages")}>Stages</a></li> + case _ => <li><a href={prependBaseUri("/stages")}>Stages</a></li> + } + val storage = page match { + case Storage => <li class="active"><a href={prependBaseUri("/storage")}>Storage</a></li> + case _ => <li><a href={prependBaseUri("/storage")}>Storage</a></li> + } + val environment = page match { + case Environment => + <li class="active"><a href={prependBaseUri("/environment")}>Environment</a></li> + case _ => <li><a href={prependBaseUri("/environment")}>Environment</a></li> + } + val executors = page match { + case Executors => <li class="active"><a href={prependBaseUri("/executors")}>Executors</a></li> + case _ => <li><a href={prependBaseUri("/executors")}>Executors</a></li> + } + + <html> + <head> + <meta http-equiv="Content-type" content="text/html; charset=utf-8" /> + <link rel="stylesheet" href={prependBaseUri("/static/bootstrap.min.css")} type="text/css" /> + <link rel="stylesheet" href={prependBaseUri("/static/webui.css")} type="text/css" /> + <script src={prependBaseUri("/static/sorttable.js")} ></script> + <title>{sc.appName} - {title}</title> + </head> + <body> + <div class="navbar navbar-static-top"> + <div class="navbar-inner"> + <a href={prependBaseUri("/")} class="brand"><img src={prependBaseUri("/static/spark-logo-77x50px-hd.png")} /></a> + <ul class="nav"> + {jobs} + {storage} + {environment} + {executors} + </ul> + <p class="navbar-text pull-right"><strong>{sc.appName}</strong> application UI</p> + </div> + </div> + + <div class="container-fluid"> + <div class="row-fluid"> + <div class="span12"> + <h3 style="vertical-align: bottom; display: inline-block;"> + {title} + </h3> + </div> + </div> + {content} + </div> + </body> + </html> + } + + /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */ + def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = { + <html> + <head> + <meta http-equiv="Content-type" content="text/html; charset=utf-8" /> + <link rel="stylesheet" href={prependBaseUri("/static/bootstrap.min.css")} type="text/css" /> + <link rel="stylesheet" href={prependBaseUri("/static/webui.css")} type="text/css" /> + <script src={prependBaseUri("/static/sorttable.js")} ></script> + <title>{title}</title> + </head> + <body> + <div class="container-fluid"> + <div class="row-fluid"> + <div class="span12"> + <h3 style="vertical-align: middle; display: inline-block;"> + <img src={prependBaseUri("/static/spark-logo-77x50px-hd.png")} style="margin-right: 15px;" /> + {title} + </h3> + </div> + </div> + {content} + </div> + </body> + </html> + } + + /** Returns an HTML table constructed by generating a row for each object in a sequence. 
*/ + def listingTable[T]( + headers: Seq[String], + makeRow: T => Seq[Node], + rows: Seq[T], + fixedWidth: Boolean = false): Seq[Node] = { + + val colWidth = 100.toDouble / headers.size + val colWidthAttr = if (fixedWidth) colWidth + "%" else "" + var tableClass = "table table-bordered table-striped table-condensed sortable" + if (fixedWidth) { + tableClass += " table-fixed" + } + + <table class={tableClass}> + <thead>{headers.map(h => <th width={colWidthAttr}>{h}</th>)}</thead> + <tbody> + {rows.map(r => makeRow(r))} + </tbody> + </table> + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala new file mode 100644 index 0000000000..0ecb22d2f9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.scheduler.cluster.SchedulingMode + + +/** + * Continuously generates jobs that expose various features of the WebUI (internal testing tool). 
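+ * It repeatedly submits a fixed set of jobs (a count, a cached count, a single shuffle, and
+ * stages that fail on purpose), pausing between submissions, so that each stage state shows up
+ * in the web UI.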
+ * + * Usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] + */ +private[spark] object UIWorkloadGenerator { + val NUM_PARTITIONS = 100 + val INTER_JOB_WAIT_MS = 5000 + + def main(args: Array[String]) { + if (args.length < 2) { + println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + System.exit(1) + } + val master = args(0) + val schedulingMode = SchedulingMode.withName(args(1)) + val appName = "Spark UI Tester" + + if (schedulingMode == SchedulingMode.FAIR) { + System.setProperty("spark.cluster.schedulingmode", "FAIR") + } + val sc = new SparkContext(master, appName) + + def setProperties(s: String) = { + if (schedulingMode == SchedulingMode.FAIR) { + sc.setLocalProperty("spark.scheduler.cluster.fair.pool", s) + } + sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) + } + + val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) + def nextFloat() = (new Random()).nextFloat() + + val jobs = Seq[(String, () => Long)]( + ("Count", baseData.count), + ("Cache and Count", baseData.map(x => x).cache.count), + ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), + ("Entirely failed phase", baseData.map(x => throw new Exception).count), + ("Partially failed phase", { + baseData.map{x => + val probFailure = (4.0 / NUM_PARTITIONS) + if (nextFloat() < probFailure) { + throw new Exception("This is a task failure") + } + 1 + }.count + }), + ("Partially failed phase (longer tasks)", { + baseData.map{x => + val probFailure = (4.0 / NUM_PARTITIONS) + if (nextFloat() < probFailure) { + Thread.sleep(100) + throw new Exception("This is a task failure") + } + 1 + }.count + }), + ("Job with delays", baseData.map(x => Thread.sleep(100)).count) + ) + + while (true) { + for ((desc, job) <- jobs) { + new Thread { + override def run() { + try { + setProperties(desc) + job() + println("Job finished: " + desc) + } catch { + case e: Exception => + println("Job failed: " + desc) + } + } + }.start + Thread.sleep(INTER_JOB_WAIT_MS) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala new file mode 100644 index 0000000000..c5bf2acc9e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.ui.env + +import javax.servlet.http.HttpServletRequest + +import scala.collection.JavaConversions._ +import scala.util.Properties +import scala.xml.Node + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.Page.Environment +import org.apache.spark.SparkContext + + +private[spark] class EnvironmentUI(sc: SparkContext) { + + def getHandlers = Seq[(String, Handler)]( + ("/environment", (request: HttpServletRequest) => envDetails(request)) + ) + + def envDetails(request: HttpServletRequest): Seq[Node] = { + val jvmInformation = Seq( + ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)), + ("Java Home", Properties.javaHome), + ("Scala Version", Properties.versionString), + ("Scala Home", Properties.scalaHome) + ).sorted + def jvmRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr> + def jvmTable = + UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true) + + val properties = System.getProperties.iterator.toSeq + val classPathProperty = properties.find { case (k, v) => + k.contains("java.class.path") + }.getOrElse(("", "")) + val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted + val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted + + val propertyHeaders = Seq("Name", "Value") + def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr> + val sparkPropertyTable = + UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true) + val otherPropertyTable = + UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) + + val classPathEntries = classPathProperty._2 + .split(System.getProperty("path.separator", ":")) + .filterNot(e => e.isEmpty) + .map(e => (e, "System Classpath")) + val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} + val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")} + val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted + + val classPathHeaders = Seq("Resource", "Source") + def classPathRow(data: (String, String)) = <tr><td>{data._1}</td><td>{data._2}</td></tr> + val classPathTable = + UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true) + + val content = + <span> + <h4>Runtime Information</h4> {jvmTable} + <h4>Spark Properties</h4> + {sparkPropertyTable} + <h4>System Properties</h4> + {otherPropertyTable} + <h4>Classpath Entries</h4> + {classPathTable} + </span> + + UIUtils.headerSparkPage(content, sc, "Environment", Environment) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala new file mode 100644 index 0000000000..d1868dcf78 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable.{HashMap, HashSet} +import scala.xml.Node + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.{ExceptionFailure, Logging, SparkContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.Page.Executors +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils + + +private[spark] class ExecutorsUI(val sc: SparkContext) { + + private var _listener: Option[ExecutorsListener] = None + def listener = _listener.get + + def start() { + _listener = Some(new ExecutorsListener) + sc.addSparkListener(listener) + } + + def getHandlers = Seq[(String, Handler)]( + ("/executors", (request: HttpServletRequest) => render(request)) + ) + + def render(request: HttpServletRequest): Seq[Node] = { + val storageStatusList = sc.getExecutorStorageStatus + + val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_+_) + val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) + + val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") + + def execRow(kv: Seq[String]) = { + <tr> + <td>{kv(0)}</td> + <td>{kv(1)}</td> + <td>{kv(2)}</td> + <td sorttable_customkey={kv(3)}> + {Utils.bytesToString(kv(3).toLong)} / {Utils.bytesToString(kv(4).toLong)} + </td> + <td sorttable_customkey={kv(5)}> + {Utils.bytesToString(kv(5).toLong)} + </td> + <td>{kv(6)}</td> + <td>{kv(7)}</td> + <td>{kv(8)}</td> + <td>{kv(9)}</td> + </tr> + } + + val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execTable = UIUtils.listingTable(execHead, execRow, execInfo) + + val content = + <div class="row-fluid"> + <div class="span12"> + <ul class="unstyled"> + <li><strong>Memory:</strong> + {Utils.bytesToString(memUsed)} Used + ({Utils.bytesToString(maxMem)} Total) </li> + <li><strong>Disk:</strong> {Utils.bytesToString(diskSpaceUsed)} Used </li> + </ul> + </div> + </div> + <div class = "row"> + <div class="span12"> + {execTable} + </div> + </div>; + + UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) + } + + def getExecInfo(a: Int): Seq[String] = { + val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId + val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort + val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString + val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString + val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString + val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString + val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) + val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 
0) + val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + val totalTasks = activeTasks + failedTasks + completedTasks + + Seq( + execId, + hostPort, + rddBlocks, + memUsed, + maxMem, + diskUsed, + activeTasks.toString, + failedTasks.toString, + completedTasks.toString, + totalTasks.toString + ) + } + + private[spark] class ExecutorsListener extends SparkListener with Logging { + val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() + val executorToTasksComplete = HashMap[String, Int]() + val executorToTasksFailed = HashMap[String, Int]() + + override def onTaskStart(taskStart: SparkListenerTaskStart) { + val eid = taskStart.taskInfo.executorId + val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + activeTasks += taskStart.taskInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val eid = taskEnd.taskInfo.executorId + val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + activeTasks -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + taskEnd.reason match { + case e: ExceptionFailure => + executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 + (Some(e), e.metrics) + case _ => + executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 + (None, Option(taskEnd.taskMetrics)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala new file mode 100644 index 0000000000..3b428effaf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} + +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.ui.Page._ +import org.apache.spark.ui.UIUtils._ + + +/** Page showing list of all ongoing and recently finished stages and pools*/ +private[spark] class IndexPage(parent: JobProgressUI) { + def listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeStages = listener.activeStages.toSeq + val completedStages = listener.completedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse.toSeq + val now = System.currentTimeMillis() + + var activeTime = 0L + for (tasks <- listener.stageToTasksActive.values; t <- tasks) { + activeTime += t.timeRunning(now) + } + + val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) + val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + + val pools = listener.sc.getAllPools + val poolTable = new PoolTable(pools, listener) + val summary: NodeSeq = + <div> + <ul class="unstyled"> + <li> + <strong>Total Duration: </strong> + {parent.formatDuration(now - listener.sc.startTime)} + </li> + <li><strong>Scheduling Mode:</strong> {parent.sc.getSchedulingMode}</li> + <li> + <a href="#active"><strong>Active Stages:</strong></a> + {activeStages.size} + </li> + <li> + <a href="#completed"><strong>Completed Stages:</strong></a> + {completedStages.size} + </li> + <li> + <a href="#failed"><strong>Failed Stages:</strong></a> + {failedStages.size} + </li> + </ul> + </div> + + val content = summary ++ + {if (listener.sc.getSchedulingMode == SchedulingMode.FAIR) { + <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq + } else { + Seq() + }} ++ + <h4 id="active">Active Stages ({activeStages.size})</h4> ++ + activeStagesTable.toNodeSeq++ + <h4 id="completed">Completed Stages ({completedStages.size})</h4> ++ + completedStagesTable.toNodeSeq++ + <h4 id ="failed">Failed Stages ({failedStages.size})</h4> ++ + failedStagesTable.toNodeSeq + + headerSparkPage(content, parent.sc, "Spark Stages", Stages) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala new file mode 100644 index 0000000000..e2bcd98545 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import scala.Seq +import scala.collection.mutable.{ListBuffer, HashMap, HashSet} + +import org.apache.spark.{ExceptionFailure, SparkContext, Success} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.executor.TaskMetrics +import collection.mutable + +/** + * Tracks task-level information to be displayed in the UI. + * + * All access to the data structures in this class must be synchronized on the + * class, since the UI thread and the DAGScheduler event loop may otherwise + * be reading/updating the internal data structures concurrently. + */ +private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { + // How many stages to remember + val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt + val DEFAULT_POOL_NAME = "default" + + val stageToPool = new HashMap[Stage, String]() + val stageToDescription = new HashMap[Stage, String]() + val poolToActiveStages = new HashMap[String, HashSet[Stage]]() + + val activeStages = HashSet[Stage]() + val completedStages = ListBuffer[Stage]() + val failedStages = ListBuffer[Stage]() + + // Total metrics reflect metrics only for completed tasks + var totalTime = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + + val stageToTime = HashMap[Int, Long]() + val stageToShuffleRead = HashMap[Int, Long]() + val stageToShuffleWrite = HashMap[Int, Long]() + val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() + val stageToTasksComplete = HashMap[Int, Int]() + val stageToTasksFailed = HashMap[Int, Int]() + val stageToTaskInfos = + HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() + + override def onJobStart(jobStart: SparkListenerJobStart) {} + + override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { + val stage = stageCompleted.stageInfo.stage + poolToActiveStages(stageToPool(stage)) -= stage + activeStages -= stage + completedStages += stage + trimIfNecessary(completedStages) + } + + /** If stages is too large, remove and garbage collect old stages */ + def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { + if (stages.size > RETAINED_STAGES) { + val toRemove = RETAINED_STAGES / 10 + stages.takeRight(toRemove).foreach( s => { + stageToTaskInfos.remove(s.id) + stageToTime.remove(s.id) + stageToShuffleRead.remove(s.id) + stageToShuffleWrite.remove(s.id) + stageToTasksActive.remove(s.id) + stageToTasksComplete.remove(s.id) + stageToTasksFailed.remove(s.id) + stageToPool.remove(s) + if (stageToDescription.contains(s)) {stageToDescription.remove(s)} + }) + stages.trimEnd(toRemove) + } + } + + /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + val stage = stageSubmitted.stage + activeStages += stage + + val poolName = Option(stageSubmitted.properties).map { + p => p.getProperty("spark.scheduler.cluster.fair.pool", DEFAULT_POOL_NAME) + }.getOrElse(DEFAULT_POOL_NAME) + stageToPool(stage) = poolName + + val description = Option(stageSubmitted.properties).flatMap { + p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + description.map(d => stageToDescription(stage) = d) + + val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) + stages += stage + } + + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + val sid = 
taskStart.task.stageId + val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + tasksActive += taskStart.taskInfo + val taskList = stageToTaskInfos.getOrElse( + sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) + taskList += ((taskStart.taskInfo, None, None)) + stageToTaskInfos(sid) = taskList + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + val sid = taskEnd.task.stageId + val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + tasksActive -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + taskEnd.reason match { + case e: ExceptionFailure => + stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e), e.metrics) + case _ => + stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1 + (None, Option(taskEnd.taskMetrics)) + } + + stageToTime.getOrElseUpdate(sid, 0L) + val time = metrics.map(m => m.executorRunTime).getOrElse(0) + stageToTime(sid) += time + totalTime += time + + stageToShuffleRead.getOrElseUpdate(sid, 0L) + val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => + s.remoteBytesRead).getOrElse(0L) + stageToShuffleRead(sid) += shuffleRead + totalShuffleRead += shuffleRead + + stageToShuffleWrite.getOrElseUpdate(sid, 0L) + val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => + s.shuffleBytesWritten).getOrElse(0L) + stageToShuffleWrite(sid) += shuffleWrite + totalShuffleWrite += shuffleWrite + + val taskList = stageToTaskInfos.getOrElse( + sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) + taskList -= ((taskEnd.taskInfo, None, None)) + taskList += ((taskEnd.taskInfo, metrics, failureInfo)) + stageToTaskInfos(sid) = taskList + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + jobEnd match { + case end: SparkListenerJobEnd => + end.jobResult match { + case JobFailed(ex, Some(stage)) => + activeStages -= stage + poolToActiveStages(stageToPool(stage)) -= stage + failedStages += stage + trimIfNecessary(failedStages) + case _ => + } + case _ => + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala new file mode 100644 index 0000000000..54e273fd8b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import scala.concurrent.duration._ + +import java.text.SimpleDateFormat + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import scala.Seq +import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} + +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.{ExceptionFailure, SparkContext, Success} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode +import org.apache.spark.util.Utils + +/** Web UI showing progress status of all jobs in the given SparkContext. */ +private[spark] class JobProgressUI(val sc: SparkContext) { + private var _listener: Option[JobProgressListener] = None + def listener = _listener.get + val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + + private val indexPage = new IndexPage(this) + private val stagePage = new StagePage(this) + private val poolPage = new PoolPage(this) + + def start() { + _listener = Some(new JobProgressListener(sc)) + sc.addSparkListener(listener) + } + + def formatDuration(ms: Long) = Utils.msDurationToString(ms) + + def getHandlers = Seq[(String, Handler)]( + ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), + ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), + ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + ) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala new file mode 100644 index 0000000000..89fffcb80d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} +import scala.collection.mutable.HashSet + +import org.apache.spark.scheduler.Stage +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + +/** Page showing specific pool details */ +private[spark] class PoolPage(parent: JobProgressUI) { + def listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val poolName = request.getParameter("poolname") + val poolToActiveStages = listener.poolToActiveStages + val activeStages = poolToActiveStages.get(poolName).toSeq.flatten + val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + + val pool = listener.sc.getPoolForName(poolName).get + val poolTable = new PoolTable(Seq(pool), listener) + + val content = <h4>Summary </h4> ++ poolTable.toNodeSeq() ++ + <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq() + + headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala new file mode 100644 index 0000000000..b3d3666944 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.xml.Node + +import org.apache.spark.scheduler.Stage +import org.apache.spark.scheduler.cluster.Schedulable +import org.apache.spark.ui.UIUtils + +/** Table showing list of pools */ +private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) { + + var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + poolTable(poolRow, pools) + } + } + + private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node], + rows: Seq[Schedulable] + ): Seq[Node] = { + <table class="table table-bordered table-striped table-condensed sortable table-fixed"> + <thead> + <th>Pool Name</th> + <th>Minimum Share</th> + <th>Pool Weight</th> + <th>Active Stages</th> + <th>Running Tasks</th> + <th>SchedulingMode</th> + </thead> + <tbody> + {rows.map(r => makeRow(r, poolToActiveStages))} + </tbody> + </table> + } + + private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]]) + : Seq[Node] = { + val activeStages = poolToActiveStages.get(p.name) match { + case Some(stages) => stages.size + case None => 0 + } + <tr> + <td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),p.name)}>{p.name}</a></td> + <td>{p.minShare}</td> + <td>{p.weight}</td> + <td>{activeStages}</td> + <td>{p.runningTasks}</td> + <td>{p.schedulingMode}</td> + </tr> + } +} + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala new file mode 100644 index 0000000000..a9969ab1c0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import java.util.Date + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ +import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.{ExceptionFailure} +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.executor.TaskMetrics + +/** Page showing statistics and task list for a given stage */ +private[spark] class StagePage(parent: JobProgressUI) { + def listener = parent.listener + val dateFmt = parent.dateFmt + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val stageId = request.getParameter("id").toInt + val now = System.currentTimeMillis() + + if (!listener.stageToTaskInfos.contains(stageId)) { + val content = + <div> + <h4>Summary Metrics</h4> No tasks have started yet + <h4>Tasks</h4> No tasks have started yet + </div> + return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages) + } + + val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) + + val numCompleted = tasks.count(_._1.finished) + val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L) + val hasShuffleRead = shuffleReadBytes > 0 + val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L) + val hasShuffleWrite = shuffleWriteBytes > 0 + + var activeTime = 0L + listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + + val summary = + <div> + <ul class="unstyled"> + <li> + <strong>CPU time: </strong> + {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)} + </li> + {if (hasShuffleRead) + <li> + <strong>Shuffle read: </strong> + {Utils.bytesToString(shuffleReadBytes)} + </li> + } + {if (hasShuffleWrite) + <li> + <strong>Shuffle write: </strong> + {Utils.bytesToString(shuffleWriteBytes)} + </li> + } + </ul> + </div> + + val taskHeaders: Seq[String] = + Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ + Seq("GC Time") ++ + {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ + {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ + Seq("Errors") + + val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) + + // Excludes tasks which failed and have incomplete metrics + val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined)) + + val summaryTable: Option[Seq[Node]] = + if (validTasks.size == 0) { + None + } + else { + val serviceTimes = validTasks.map{case (info, metrics, exception) => + metrics.get.executorRunTime.toDouble} + val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( + ms => parent.formatDuration(ms.toLong)) + + def getQuantileCols(data: Seq[Double]) = + Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) + + val shuffleReadSizes = validTasks.map { + case(info, metrics, exception) => + metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + } + val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes) + + val shuffleWriteSizes = validTasks.map { + case(info, metrics, exception) => + metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble + } + val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) + + val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + if (hasShuffleRead) shuffleReadQuantiles else 
Nil, + if (hasShuffleWrite) shuffleWriteQuantiles else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", + "Median", "75th percentile", "Max") + def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr> + Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) + } + + val content = + summary ++ + <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ + <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ + <h4>Tasks</h4> ++ taskTable; + + headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) + } + } + + + def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean) + (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = { + def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = + trace.map(e => <span style="display:block;">{e.toString}</span>) + val (info, metrics, exception) = taskData + + val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) + else metrics.map(m => m.executorRunTime).getOrElse(1) + val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration) + else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") + val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) + + <tr> + <td>{info.taskId}</td> + <td>{info.status}</td> + <td>{info.taskLocality}</td> + <td>{info.host}</td> + <td>{dateFmt.format(new Date(info.launchTime))}</td> + <td sorttable_customkey={duration.toString}> + {formatDuration} + </td> + <td sorttable_customkey={gcTime.toString}> + {if (gcTime > 0) parent.formatDuration(gcTime) else ""} + </td> + {if (shuffleRead) { + <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s => + Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td> + }} + {if (shuffleWrite) { + <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td> + }} + <td>{exception.map(e => + <span> + {e.className} ({e.description})<br/> + {fmtStackTrace(e.stackTrace)} + </span>).getOrElse("")} + </td> + </tr> + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala new file mode 100644 index 0000000000..32776eaa25 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import java.util.Date + +import scala.xml.Node +import scala.collection.mutable.HashSet + +import org.apache.spark.scheduler.cluster.{SchedulingMode, TaskInfo} +import org.apache.spark.scheduler.Stage +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils + + +/** Page showing list of all ongoing and recently finished stages */ +private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { + + val listener = parent.listener + val dateFmt = parent.dateFmt + val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + stageTable(stageRow, stages) + } + } + + /** Special table which merges two header cells. */ + private def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = { + <table class="table table-bordered table-striped table-condensed sortable"> + <thead> + <th>Stage Id</th> + {if (isFairScheduler) {<th>Pool Name</th>} else {}} + <th>Description</th> + <th>Submitted</th> + <th>Duration</th> + <th>Tasks: Succeeded/Total</th> + <th>Shuffle Read</th> + <th>Shuffle Write</th> + </thead> + <tbody> + {rows.map(r => makeRow(r))} + </tbody> + </table> + } + + private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + + <div class="progress"> + <span style="text-align:center; position:absolute; width:100%;"> + {completed}/{total} {failed} + </span> + <div class="bar bar-completed" style={completeWidth}></div> + <div class="bar bar-running" style={startWidth}></div> + </div> + } + + + private def stageRow(s: Stage): Seq[Node] = { + val submissionTime = s.submissionTime match { + case Some(t) => dateFmt.format(new Date(t)) + case None => "Unknown" + } + + val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { + case 0 => "" + case b => Utils.bytesToString(b) + } + val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { + case 0 => "" + case b => Utils.bytesToString(b) + } + + val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size + val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) + val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match { + case f if f > 0 => "(%s failed)".format(f) + case _ => "" + } + val totalTasks = s.numPartitions + + val poolName = listener.stageToPool.get(s) + + val nameLink = + <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a> + val description = listener.stageToDescription.get(s) + .map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink) + val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) + val duration = s.submissionTime.map(t => finishTime - t) + + <tr> + <td>{s.id}</td> + {if (isFairScheduler) { + <td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}> + {poolName.get}</a></td>} + } + <td>{description}</td> + <td valign="middle">{submissionTime}</td> + <td sorttable_customkey={duration.getOrElse(-1).toString}> + {duration.map(d => parent.formatDuration(d)).getOrElse("Unknown")} + </td> + <td class="progress-cell"> + {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} + </td> + <td>{shuffleRead}</td> + <td>{shuffleWrite}</td> + </tr> + } +} diff --git 
a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala new file mode 100644 index 0000000000..a5446b3fc3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import scala.concurrent.duration._ + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.ui.JettyUtils._ + +/** Web UI showing storage status of all RDD's in the given SparkContext. */ +private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { + implicit val timeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + val indexPage = new IndexPage(this) + val rddPage = new RDDPage(this) + + def getHandlers = Seq[(String, Handler)]( + ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), + ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + ) +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala new file mode 100644 index 0000000000..109a7d4094 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.storage + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.storage.{RDDInfo, StorageUtils} +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ +import org.apache.spark.util.Utils + +/** Page showing list of RDD's currently stored in the cluster */ +private[spark] class IndexPage(parent: BlockManagerUI) { + val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val storageStatusList = sc.getExecutorStorageStatus + // Calculate macro-level statistics + + val rddHeaders = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size on Disk") + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + val content = listingTable(rddHeaders, rddRow, rdds) + + headerSparkPage(content, parent.sc, "Storage ", Storage) + } + + def rddRow(rdd: RDDInfo): Seq[Node] = { + <tr> + <td> + <a href={"%s/storage/rdd?id=%s".format(prependBaseUri(),rdd.id)}> + {rdd.name} + </a> + </td> + <td>{rdd.storageLevel.description} + </td> + <td>{rdd.numCachedPartitions}</td> + <td>{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}</td> + <td>{Utils.bytesToString(rdd.memSize)}</td> + <td>{Utils.bytesToString(rdd.diskSize)}</td> + </tr> + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala new file mode 100644 index 0000000000..43c1257677 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.storage.{StorageStatus, StorageUtils} +import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ +import org.apache.spark.util.Utils + + +/** Page showing storage details for a given RDD */ +private[spark] class RDDPage(parent: BlockManagerUI) { + val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val id = request.getParameter("id") + val prefix = "rdd_" + id.toString + val storageStatusList = sc.getExecutorStorageStatus + val filteredStorageStatusList = StorageUtils. 
+ filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + + val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") + val workers = filteredStorageStatusList.map((prefix, _)) + val workerTable = listingTable(workerHeaders, workerRow, workers) + + val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", + "Executors") + + val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) + val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) + val blocks = blockStatuses.map { + case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) + } + val blockTable = listingTable(blockHeaders, blockRow, blocks) + + val content = + <div class="row-fluid"> + <div class="span12"> + <ul class="unstyled"> + <li> + <strong>Storage Level:</strong> + {rddInfo.storageLevel.description} + </li> + <li> + <strong>Cached Partitions:</strong> + {rddInfo.numCachedPartitions} + </li> + <li> + <strong>Total Partitions:</strong> + {rddInfo.numPartitions} + </li> + <li> + <strong>Memory Size:</strong> + {Utils.bytesToString(rddInfo.memSize)} + </li> + <li> + <strong>Disk Size:</strong> + {Utils.bytesToString(rddInfo.diskSize)} + </li> + </ul> + </div> + </div> + + <div class="row-fluid"> + <div class="span12"> + <h4> Data Distribution on {workers.size} Executors </h4> + {workerTable} + </div> + </div> + + <div class="row-fluid"> + <div class="span12"> + <h4> {blocks.size} Partitions </h4> + {blockTable} + </div> + </div>; + + headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) + } + + def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { + val (id, block, locations) = row + <tr> + <td>{id}</td> + <td> + {block.storageLevel.description} + </td> + <td sorttable_customkey={block.memSize.toString}> + {Utils.bytesToString(block.memSize)} + </td> + <td sorttable_customkey={block.diskSize.toString}> + {Utils.bytesToString(block.diskSize)} + </td> + <td> + {locations.map(l => <span>{l}<br/></span>)} + </td> + </tr> + } + + def workerRow(worker: (String, StorageStatus)): Seq[Node] = { + val (prefix, status) = worker + <tr> + <td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td> + <td> + {Utils.bytesToString(status.memUsed(prefix))} + ({Utils.bytesToString(status.memRemaining)} Remaining) + </td> + <td>{Utils.bytesToString(status.diskUsed(prefix))}</td> + </tr> + } +} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala new file mode 100644 index 0000000000..e674d120ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.{ActorSystem, ExtendedActorSystem} +import com.typesafe.config.ConfigFactory +import scala.concurrent.duration._ +import scala.concurrent.Await +import akka.remote.RemoteActorRefProvider + +/** + * Various utility classes for working with Akka. + */ +private[spark] object AkkaUtils { + + /** + * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the + * ActorSystem itself and its port (which is hard to get from Akka). + * + * Note: the `name` parameter is important, as even if a client sends a message to right + * host + port, if the system name is incorrect, Akka will drop the message. + */ + def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { + val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt + val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt + + val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt + + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. + val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + + val akkaConf = ConfigFactory.parseString(""" + akka.daemonic = on + akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] + akka.stdout-loglevel = "ERROR" + akka.actor.provider = "akka.remote.RemoteActorRefProvider" + akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.netty.hostname = "%s" + akka.remote.netty.port = %d + akka.remote.netty.connection-timeout = %ds + akka.remote.netty.message-frame-size = %d MiB + akka.remote.netty.execution-pool-size = %d + akka.actor.default-dispatcher.throughput = %d + akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds + """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, + lifecycleEvents, akkaWriteTimeout)) + + val actorSystem = ActorSystem(name, akkaConf) + + // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a + // hack because Akka doesn't let you figure out the port through the public API yet. + val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider + val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get + return (actorSystem, boundPort) + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..0b51c23f7b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable +import scala.collection.JavaConverters._ + +/** + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. + */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) + extends Iterable[A] with Growable[A] with Serializable { + + private val underlying = new JPriorityQueue[A](maxSize, ord) + + override def iterator: Iterator[A] = underlying.iterator.asScala + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: A): this.type = { + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = underlying.peek() + if (head != null && ord.gt(a, head)) { + underlying.poll() + underlying.offer(a) + } else false + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala new file mode 100644 index 0000000000..e214d2a519 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.InputStream +import java.nio.ByteBuffer +import org.apache.spark.storage.BlockManager + +/** + * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() + * at the end of the stream (e.g. to close a memory-mapped file). 
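+ *
+ * A small illustrative sketch (`mappedBuffer` stands for a hypothetical memory-mapped ByteBuffer):
+ * {{{
+ *   val in = new ByteBufferInputStream(mappedBuffer, dispose = true)
+ *   val b = in.read()   // 0-255 while bytes remain; -1 once exhausted, after which the
+ *                       // buffer has been disposed via BlockManager.dispose()
+ * }}}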
+ */ +private[spark] +class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) + extends InputStream { + + override def read(): Int = { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() + -1 + } else { + buffer.get() & 0xFF + } + } + + override def read(dest: Array[Byte]): Int = { + read(dest, 0, dest.length) + } + + override def read(dest: Array[Byte], offset: Int, length: Int): Int = { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() + -1 + } else { + val amountToGet = math.min(buffer.remaining(), length) + buffer.get(dest, offset, amountToGet) + amountToGet + } + } + + override def skip(bytes: Long): Long = { + if (buffer != null) { + val amountToSkip = math.min(bytes, buffer.remaining).toInt + buffer.position(buffer.position + amountToSkip) + if (buffer.remaining() == 0) { + cleanUp() + } + amountToSkip + } else { + 0L + } + } + + /** + * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). + */ + private def cleanUp() { + if (buffer != null) { + if (dispose) { + BlockManager.dispose(buffer) + } + buffer = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala new file mode 100644 index 0000000000..97c2b45aab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Clock.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An interface to represent clocks, so that they can be mocked out in unit tests. + */ +private[spark] trait Clock { + def getTime(): Long +} + +private[spark] object SystemClock extends Clock { + def getTime(): Long = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala new file mode 100644 index 0000000000..7108595e3e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
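A usage sketch for the BoundedPriorityQueue defined earlier in this change: with maxSize = 3 and the default Int ordering, only the three largest elements are retained.

val topThree = new BoundedPriorityQueue[Int](3)
topThree ++= Seq(5, 1, 9, 7, 3)
// The queue now holds 5, 7 and 9 (iteration is in heap order, not sorted order).
println(topThree.toList.sorted)   // List(5, 7, 9)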
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.reflect.Field + +import scala.collection.mutable.Map +import scala.collection.mutable.Set + +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.objectweb.asm.Opcodes._ +import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} +import org.apache.spark.Logging + +private[spark] object ClosureCleaner extends Logging { + // Get an ASM class reader for a given class from the JAR that loaded it + private def getClassReader(cls: Class[_]): ClassReader = { + // Copy data over, before delegating to ClassReader - else we can run out of open file handles. + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + // todo: Fixme - continuing with earlier behavior ... + if (resourceStream == null) return new ClassReader(resourceStream) + + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + } + + // Check whether a class represents a Scala closure + private def isClosure(cls: Class[_]): Boolean = { + cls.getName.contains("$anonfun$") + } + + // Get a list of the classes of the outer objects of a given closure object, obj; + // the outer objects are defined as any closures that obj is nested within, plus + // possibly the class that the outermost closure is in, if any. We stop searching + // for outer objects beyond that because cloning the user's object is probably + // not a good idea (whereas we can clone closure objects just fine since we + // understand how all their fields are used). + private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { + f.setAccessible(true) + if (isClosure(f.getType)) { + return f.getType :: getOuterClasses(f.get(obj)) + } else { + return f.getType :: Nil // Stop at the first $outer that is not a closure + } + } + return Nil + } + + // Get a list of the outer objects for a given closure object. 
+ private def getOuterObjects(obj: AnyRef): List[AnyRef] = { + for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { + f.setAccessible(true) + if (isClosure(f.getType)) { + return f.get(obj) :: getOuterObjects(f.get(obj)) + } else { + return f.get(obj) :: Nil // Stop at the first $outer that is not a closure + } + } + return Nil + } + + private def getInnerClasses(obj: AnyRef): List[Class[_]] = { + val seen = Set[Class[_]](obj.getClass) + var stack = List[Class[_]](obj.getClass) + while (!stack.isEmpty) { + val cr = getClassReader(stack.head) + stack = stack.tail + val set = Set[Class[_]]() + cr.accept(new InnerClosureFinder(set), 0) + for (cls <- set -- seen) { + seen += cls + stack = cls :: stack + } + } + return (seen - obj.getClass).toList + } + + private def createNullValue(cls: Class[_]): AnyRef = { + if (cls.isPrimitive) { + new java.lang.Byte(0: Byte) // Should be convertible to any primitive type + } else { + null + } + } + + def clean(func: AnyRef) { + // TODO: cache outerClasses / innerClasses / accessedFields + val outerClasses = getOuterClasses(func) + val innerClasses = getInnerClasses(func) + val outerObjects = getOuterObjects(func) + + val accessedFields = Map[Class[_], Set[String]]() + for (cls <- outerClasses) + accessedFields(cls) = Set[String]() + for (cls <- func.getClass :: innerClasses) + getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) + //logInfo("accessedFields: " + accessedFields) + + val inInterpreter = { + try { + val interpClass = Class.forName("spark.repl.Main") + interpClass.getMethod("interp").invoke(null) != null + } catch { + case _: ClassNotFoundException => true + } + } + + var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse + var outer: AnyRef = null + if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + outer = outerPairs.head._2 + outerPairs = outerPairs.tail + } + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. 
+ for ((cls, obj) <- outerPairs) { + outer = instantiateClass(cls, outer, inInterpreter) + for (fieldName <- accessedFields(cls)) { + val field = cls.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); + field.set(outer, value) + } + } + + if (outer != null) { + //logInfo("2: Setting $outer on " + func.getClass + " to " + outer); + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + field.set(func, outer) + } + } + + private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { + //logInfo("Creating a " + cls + " with outer = " + outer) + if (!inInterpreter) { + // This is a bona fide closure class, whose constructor has no effects + // other than to set its fields, so use its constructor + val cons = cls.getConstructors()(0) + val params = cons.getParameterTypes.map(createNullValue).toArray + if (outer != null) + params(0) = outer // First param is always outer object + return cons.newInstance(params: _*).asInstanceOf[AnyRef] + } else { + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory() + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef] + if (outer != null) { + //logInfo("3: Setting $outer on " + cls + " to " + outer); + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, outer) + } + return obj + } + } +} + +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + return new MethodVisitor(ASM4) { + override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { + if (op == GETFIELD) { + for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { + output(cl) += name + } + } + } + + override def visitMethodInsn(op: Int, owner: String, name: String, + desc: String) { + // Check for calls a getter method for a variable in an interpreter wrapper object. + // This means that the corresponding field will be accessed, so we should save it. + if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { + for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { + output(cl) += name + } + } + } + } + } +} + +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { + var myName: String = null + + override def visit(version: Int, access: Int, name: String, sig: String, + superName: String, interfaces: Array[String]) { + myName = name + } + + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + return new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, + desc: String) { + val argTypes = Type.getArgumentTypes(desc) + if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0 + && argTypes(0).toString.startsWith("L") // is it an object? 
+ && argTypes(0).getInternalName == myName) + output += Class.forName( + owner.replace('/', '.'), + false, + Thread.currentThread.getContextClassLoader) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala new file mode 100644 index 0000000000..dc15a38b29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements + */ +abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{ + def next = sub.next + def hasNext = { + val r = sub.hasNext + if (!r) { + completion + } + r + } + + def completion() +} + +object CompletionIterator { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + new CompletionIterator[A,I](sub) { + def completion() = completionFunction + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala new file mode 100644 index 0000000000..33bf3562fe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +/** + * Util for getting some stats from a small sample of numeric values, with some handy summary functions. + * + * Entirely in memory, not intended as a good way to compute stats over large data sets. 
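A short sketch of how the CompletionIterator above is typically used: run a cleanup action exactly once, as soon as the wrapped iterator is exhausted. The file name is hypothetical.

val source = scala.io.Source.fromFile("events.log")
val lines = CompletionIterator[String, Iterator[String]](source.getLines(), source.close())
lines.foreach(println)   // source.close() fires when hasNext first returns false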
+ * + * Assumes you are giving it a non-empty set of data + */ +class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { + require(startIdx < endIdx) + def this(data: Traversable[Double]) = this(data.toArray, 0, data.size) + java.util.Arrays.sort(data, startIdx, endIdx) + val length = endIdx - startIdx + + val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) + + /** + * Get the value of the distribution at the given probabilities. Probabilities should be + * given from 0 to 1 + * @param probabilities + */ + def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = { + probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} + } + + private def closestIndex(p: Double) = { + math.min((p * length).toInt + startIdx, endIdx - 1) + } + + def showQuantiles(out: PrintStream = System.out) = { + out.println("min\t25%\t50%\t75%\tmax") + getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} + out.println + } + + def statCounter = StatCounter(data.slice(startIdx, endIdx)) + + /** + * print a summary of this distribution to the given PrintStream. + * @param out + */ + def summary(out: PrintStream = System.out) { + out.println(statCounter) + showQuantiles(out) + } +} + +object Distribution { + + def apply(data: Traversable[Double]): Option[Distribution] = { + if (data.size > 0) + Some(new Distribution(data)) + else + None + } + + def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + out.println("min\t25%\t50%\t75%\tmax") + quantiles.foreach{q => out.print(q + "\t")} + out.println + } +} diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala new file mode 100644 index 0000000000..17e55f7996 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.atomic.AtomicInteger + +/** + * A util used to get a unique generation ID. This is a wrapper around Java's + * AtomicInteger. An example usage is in BlockManager, where each BlockManager + * instance would start an Akka actor and we use this utility to assign the Akka + * actors unique names. + */ +private[spark] class IdGenerator { + private var id = new AtomicInteger + def next: Int = id.incrementAndGet +} diff --git a/core/src/main/scala/org/apache/spark/util/IntParam.scala b/core/src/main/scala/org/apache/spark/util/IntParam.scala new file mode 100644 index 0000000000..626bb49eea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IntParam.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
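A usage sketch for the Distribution helper above; the sample values are made up.

Distribution(Seq(12.0, 5.0, 7.0, 3.0, 9.0)) match {
  case Some(d) => d.summary()   // prints the StatCounter plus min/25%/50%/75%/max quantiles
  case None    => println("no data")
}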
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An extractor object for parsing strings into integers. + */ +private[spark] object IntParam { + def unapply(str: String): Option[Int] = { + try { + Some(str.toInt) + } catch { + case e: NumberFormatException => None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MemoryParam.scala b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala new file mode 100644 index 0000000000..4869c9897a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing + * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. + */ +private[spark] object MemoryParam { + def unapply(str: String): Option[Int] = { + try { + Some(Utils.memoryStringToMb(str)) + } catch { + case e: NumberFormatException => None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..a430a75451 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
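The two extractors above are meant for pattern matching over command-line arguments. A sketch (both objects are private[spark], so this only compiles inside that package; the flag names are invented):

def parse(args: List[String]) {
  args match {
    case "--cores" :: IntParam(n) :: tail =>
      println("cores = " + n); parse(tail)
    case "--memory" :: MemoryParam(mb) :: tail =>
      println("memory = " + mb + " MB"); parse(tail)
    case Nil =>
    case unknown :: _ =>
      throw new IllegalArgumentException("Unrecognized argument: " + unknown)
  }
}
parse(List("--cores", "8", "--memory", "4g"))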
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import org.apache.spark.Logging + + +/** + * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) + */ +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + private val delaySeconds = MetadataCleaner.getDelaySeconds + private val periodSeconds = math.max(10, delaySeconds / 10) + private val timer = new Timer(name + " cleanup timer", true) + + private val task = new TimerTask { + override def run() { + try { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + + if (delaySeconds > 0) { + logDebug( + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + + "and period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} + + +object MetadataCleaner { + def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } +} + diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala new file mode 100644 index 0000000000..34f1f6606f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + + +/** + * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to + * minimize object allocation. + * + * @param _1 Element 1 of this MutablePair + * @param _2 Element 2 of this MutablePair + */ +case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, + @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] + (var _1: T1, var _2: T2) + extends Product2[T1, T2] +{ + override def toString = "(" + _1 + "," + _2 + ")" + + override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] +} diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala new file mode 100644 index 0000000000..8266e5e495 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
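A sketch of how the MetadataCleaner above is wired up, assuming a cache that exposes a clearOldValues(threshTime: Long) method, such as the TimeStampedHashMap later in this change. Note that no timer is started unless spark.cleaner.ttl is set to a positive number of seconds before construction.

System.setProperty("spark.cleaner.ttl", "3600")                 // keep entries for one hour
val cache = new TimeStampedHashMap[String, String]()
val cleaner = new MetadataCleaner("exampleCache", cache.clearOldValues)
// ... use the cache ...
cleaner.cancel()                                                // stop the timer on shutdown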
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** Provides a basic/boilerplate Iterator implementation. */ +private[spark] abstract class NextIterator[U] extends Iterator[U] { + + private var gotNext = false + private var nextValue: U = _ + private var closed = false + protected var finished = false + + /** + * Method for subclasses to implement to provide the next element. + * + * If no next element is available, the subclass should set `finished` + * to `true` and may return any value (it will be ignored). + * + * This convention is required because `null` may be a valid value, + * and using `Option` seems like it might create unnecessary Some/None + * instances, given some iterators might be called in a tight loop. + * + * @return U, or set 'finished' when done + */ + protected def getNext(): U + + /** + * Method for subclasses to implement when all elements have been successfully + * iterated, and the iteration is done. + * + * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be + * called because it has no control over what happens when an exception + * happens in the user code that is calling hasNext/next. + * + * Ideally you should have another try/catch, as in HadoopRDD, that + * ensures any resources are closed should iteration fail. + */ + protected def close() + + /** + * Calls the subclass-defined close method, but only once. + * + * Usually calling `close` multiple times should be fine, but historically + * there have been issues with some InputFormats throwing exceptions. + */ + def closeIfNeeded() { + if (!closed) { + close() + closed = true + } + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + nextValue = getNext() + if (finished) { + closeIfNeeded() + } + gotNext = true + } + } + !finished + } + + override def next(): U = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } +} diff --git a/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..47e1b45004 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
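A minimal concrete subclass of the NextIterator above, as a sketch: it streams lines from a reader and closes the reader once the end is reached. Like the other private[spark] helpers, this only compiles inside the spark package; the reader wiring is hypothetical.

import java.io.BufferedReader

class LineIterator(reader: BufferedReader) extends NextIterator[String] {
  override protected def getNext(): String = {
    val line = reader.readLine()
    if (line == null) {
      finished = true   // signals hasNext to stop and triggers closeIfNeeded()
    }
    line
  }

  override protected def close() {
    reader.close()
  }
}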
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.annotation.tailrec + +import java.io.OutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 + var lastSyncTime = System.nanoTime + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, CHUNK_SIZE) + if (writeSize > 0) { + waitToWrite(writeSize) + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SYNC_INTERVAL) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala new file mode 100644 index 0000000000..f2b1ad7d0e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
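A usage sketch for the RateLimitedOutputStream above, wrapping an ordinary file stream so that sustained writes stay near the requested rate; the file name and rate are made up.

import java.io.FileOutputStream

val out = new RateLimitedOutputStream(new FileOutputStream("replay.bin"), 1024 * 1024)
out.write(new Array[Byte](8 * 1024 * 1024))   // ~8 MB at ~1 MB/s: the write sleeps as needed
out.close()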
+ */ + +package org.apache.spark.util + +import java.nio.ByteBuffer +import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream} +import java.nio.channels.Channels + +/** + * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make + * it easier to pass ByteBuffers in case class messages. + */ +private[spark] +class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { + def value = buffer + + private def readObject(in: ObjectInputStream) { + val length = in.readInt() + buffer = ByteBuffer.allocate(length) + var amountRead = 0 + val channel = Channels.newChannel(in) + while (amountRead < length) { + val ret = channel.read(buffer) + if (ret == -1) { + throw new EOFException("End of file before fully reading buffer") + } + amountRead += ret + } + buffer.rewind() // Allow us to read it later + } + + private def writeObject(out: ObjectOutputStream) { + out.writeInt(buffer.limit()) + if (Channels.newChannel(out).write(buffer) != buffer.limit()) { + throw new IOException("Could not fully write buffer to output stream") + } + buffer.rewind() // Allow us to write it again later + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala new file mode 100644 index 0000000000..a25b37a2a9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.reflect.Field +import java.lang.reflect.Modifier +import java.lang.reflect.{Array => JArray} +import java.util.IdentityHashMap +import java.util.concurrent.ConcurrentHashMap +import java.util.Random + +import javax.management.MBeanServer +import java.lang.management.ManagementFactory + +import scala.collection.mutable.ArrayBuffer + +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import org.apache.spark.Logging + +/** + * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in + * memory-aware caches. + * + * Based on the following JavaWorld article: + * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html + */ +private[spark] object SizeEstimator extends Logging { + + // Sizes of primitive types + private val BYTE_SIZE = 1 + private val BOOLEAN_SIZE = 1 + private val CHAR_SIZE = 2 + private val SHORT_SIZE = 2 + private val INT_SIZE = 4 + private val LONG_SIZE = 8 + private val FLOAT_SIZE = 4 + private val DOUBLE_SIZE = 8 + + // Alignment boundary for objects + // TODO: Is this arch dependent ? 
+ private val ALIGN_SIZE = 8 + + // A cache of ClassInfo objects for each class + private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] + + // Object and pointer sizes are arch dependent + private var is64bit = false + + // Size of an object reference + // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops + private var isCompressedOops = false + private var pointerSize = 4 + + // Minimum size of a java.lang.Object + private var objectSize = 8 + + initialize() + + // Sets object size, pointer size based on architecture and CompressedOops settings + // from the JVM. + private def initialize() { + is64bit = System.getProperty("os.arch").contains("64") + isCompressedOops = getIsCompressedOops + + objectSize = if (!is64bit) 8 else { + if(!isCompressedOops) { + 16 + } else { + 12 + } + } + pointerSize = if (is64bit && !isCompressedOops) 8 else 4 + classInfos.clear() + classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) + } + + private def getIsCompressedOops : Boolean = { + if (System.getProperty("spark.test.useCompressedOops") != null) { + return System.getProperty("spark.test.useCompressedOops").toBoolean + } + + try { + val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" + val server = ManagementFactory.getPlatformMBeanServer() + + // NOTE: This should throw an exception in non-Sun JVMs + val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") + val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", + Class.forName("java.lang.String")) + + val bean = ManagementFactory.newPlatformMXBeanProxy(server, + hotSpotMBeanName, hotSpotMBeanClass) + // TODO: We could use reflection on the VMOption returned ? + return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") + } catch { + case e: Exception => { + // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB + val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) + val guessInWords = if (guess) "yes" else "not" + logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) + return guess + } + } + } + + /** + * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an + * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects + * to visit. + */ + private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) { + val stack = new ArrayBuffer[AnyRef] + var size = 0L + + def enqueue(obj: AnyRef) { + if (obj != null && !visited.containsKey(obj)) { + visited.put(obj, null) + stack += obj + } + } + + def isFinished(): Boolean = stack.isEmpty + + def dequeue(): AnyRef = { + val elem = stack.last + stack.trimEnd(1) + return elem + } + } + + /** + * Cached information about each class. We remember two things: the "shell size" of the class + * (size of all non-static fields plus the java.lang.Object size), and any fields that are + * pointers to objects. 
+ */ + private class ClassInfo( + val shellSize: Long, + val pointerFields: List[Field]) {} + + def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) + + private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { + val state = new SearchState(visited) + state.enqueue(obj) + while (!state.isFinished) { + visitSingleObject(state.dequeue(), state) + } + return state.size + } + + private def visitSingleObject(obj: AnyRef, state: SearchState) { + val cls = obj.getClass + if (cls.isArray) { + visitArray(obj, cls, state) + } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) { + // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses + // the size estimator since it references the whole REPL. Do nothing in this case. In + // general all ClassLoaders and Classes will be shared between objects anyway. + } else { + val classInfo = getClassInfo(cls) + state.size += classInfo.shellSize + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } + } + } + + // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. + private val ARRAY_SIZE_FOR_SAMPLING = 200 + private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING + + private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { + val length = JArray.getLength(array) + val elementClass = cls.getComponentType + + // Arrays have object header and length field which is an integer + var arrSize: Long = alignSize(objectSize + INT_SIZE) + + if (elementClass.isPrimitive) { + arrSize += alignSize(length * primitiveSize(elementClass)) + state.size += arrSize + } else { + arrSize += alignSize(length * pointerSize) + state.size += arrSize + + if (length <= ARRAY_SIZE_FOR_SAMPLING) { + for (i <- 0 until length) { + state.enqueue(JArray.get(array, i)) + } + } else { + // Estimate the size of a large array by sampling elements without replacement. + var size = 0.0 + val rand = new Random(42) + val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE) + for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var index = 0 + do { + index = rand.nextInt(length) + } while (drawn.contains(index)) + drawn.add(index) + val elem = JArray.get(array, index) + size += SizeEstimator.estimate(elem, state.visited) + } + state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong + } + } + } + + private def primitiveSize(cls: Class[_]): Long = { + if (cls == classOf[Byte]) + BYTE_SIZE + else if (cls == classOf[Boolean]) + BOOLEAN_SIZE + else if (cls == classOf[Char]) + CHAR_SIZE + else if (cls == classOf[Short]) + SHORT_SIZE + else if (cls == classOf[Int]) + INT_SIZE + else if (cls == classOf[Long]) + LONG_SIZE + else if (cls == classOf[Float]) + FLOAT_SIZE + else if (cls == classOf[Double]) + DOUBLE_SIZE + else throw new IllegalArgumentException( + "Non-primitive class " + cls + " passed to primitiveSize()") + } + + /** + * Get or compute the ClassInfo for a given class. 
+ */ + private def getClassInfo(cls: Class[_]): ClassInfo = { + // Check whether we've already cached a ClassInfo for this class + val info = classInfos.get(cls) + if (info != null) { + return info + } + + val parent = getClassInfo(cls.getSuperclass) + var shellSize = parent.shellSize + var pointerFields = parent.pointerFields + + for (field <- cls.getDeclaredFields) { + if (!Modifier.isStatic(field.getModifiers)) { + val fieldClass = field.getType + if (fieldClass.isPrimitive) { + shellSize += primitiveSize(fieldClass) + } else { + field.setAccessible(true) // Enable future get()'s on this field + shellSize += pointerSize + pointerFields = field :: pointerFields + } + } + } + + shellSize = alignSize(shellSize) + + // Create and cache a new ClassInfo + val newInfo = new ClassInfo(shellSize, pointerFields) + classInfos.put(cls, newInfo) + return newInfo + } + + private def alignSize(size: Long): Long = { + val rem = size % ALIGN_SIZE + return if (rem == 0) size else (size + ALIGN_SIZE - rem) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala new file mode 100644 index 0000000000..020d5edba9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A class for tracking the statistics of a set of numbers (count, mean and variance) in a + * numerically robust way. Includes support for merging two StatCounters. Based on + * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]]. + * + * @constructor Initialize the StatCounter with the given values. + */ +class StatCounter(values: TraversableOnce[Double]) extends Serializable { + private var n: Long = 0 // Running count of our values + private var mu: Double = 0 // Running mean of our values + private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) + + merge(values) + + /** Initialize the StatCounter with no values. */ + def this() = this(Nil) + + /** Add a value into this StatCounter, updating the internal statistics. */ + def merge(value: Double): StatCounter = { + val delta = value - mu + n += 1 + mu += delta / n + m2 += delta * (value - mu) + this + } + + /** Add multiple values into this StatCounter, updating the internal statistics. */ + def merge(values: TraversableOnce[Double]): StatCounter = { + values.foreach(v => merge(v)) + this + } + + /** Merge another StatCounter into this one, adding up the internal statistics. 
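As a rough illustration of the SizeEstimator above (exact numbers depend on the JVM, pointer size and CompressedOops):

val bytes = SizeEstimator.estimate(new Array[Long](1000))
// Roughly 8016-8024 bytes: 1000 * 8 for the elements plus an aligned array header.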
*/ + def merge(other: StatCounter): StatCounter = { + if (other == this) { + merge(other.copy()) // Avoid overwriting fields in a weird order + } else { + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n + } + this + } + } + + /** Clone this StatCounter */ + def copy(): StatCounter = { + val other = new StatCounter + other.n = n + other.mu = mu + other.m2 = m2 + other + } + + def count: Long = n + + def mean: Double = mu + + def sum: Double = n * mu + + /** Return the variance of the values. */ + def variance: Double = { + if (n == 0) + Double.NaN + else + m2 / n + } + + /** + * Return the sample variance, which corrects for bias in estimating the variance by dividing + * by N-1 instead of N. + */ + def sampleVariance: Double = { + if (n <= 1) + Double.NaN + else + m2 / (n - 1) + } + + /** Return the standard deviation of the values. */ + def stdev: Double = math.sqrt(variance) + + /** + * Return the sample standard deviation of the values, which corrects for bias in estimating the + * variance by dividing by N-1 instead of N. + */ + def sampleStdev: Double = math.sqrt(sampleVariance) + + override def toString: String = { + "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) + } +} + +object StatCounter { + /** Build a StatCounter from a list of values. */ + def apply(values: TraversableOnce[Double]) = new StatCounter(values) + + /** Build a StatCounter from a list of values passed as variable-length arguments. */ + def apply(values: Double*) = new StatCounter(values) +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..dbff571de9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConversions +import scala.collection.mutable.Map +import scala.collection.immutable +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.Logging + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the clearOldValues method. 
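A usage sketch for the StatCounter above, including the merge of two partial counters; the values are made up.

val left  = StatCounter(1.0, 2.0, 3.0)
val right = StatCounter(10.0, 20.0)
left.merge(right)              // merge mutates and returns `left`
println(left.count)            // 5
println(left.mean)             // 7.2
println(left.sampleStdev)      // unbiased: divides by n - 1 instead of n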
This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def toMap: immutable.Map[A, B] = iterator.toMap + + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..26983138ff --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
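A usage sketch for the TimeStampedHashMap above, showing the time-based eviction it adds on top of the usual mutable.Map operations:

val map = new TimeStampedHashMap[String, Int]()
map("first") = 1                                       // timestamped on insertion
Thread.sleep(100)
map("second") = 2
map.clearOldValues(System.currentTimeMillis() - 50)    // drops entries older than 50 ms
println(map.toMap)                                     // Map(second -> 2)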
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala new file mode 100644 index 0000000000..d8d014de7d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -0,0 +1,788 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
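And a matching sketch for the TimeStampedHashSet above; the host names and the ten-minute threshold are invented.

val recentHosts = new TimeStampedHashSet[String]()
recentHosts += "host-a"
recentHosts += "host-b"
// Periodically drop entries that have not been (re-)added in the last 10 minutes.
recentHosts.clearOldValues(System.currentTimeMillis() - 10 * 60 * 1000)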
+ */ + +package org.apache.spark.util + +import java.io._ +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} +import java.util.{Locale, Random, UUID} + +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.regex.Pattern + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ +import scala.collection.Map +import scala.io.Source +import scala.reflect.ClassTag +import scala.Some + + +import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder + +import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} + +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.deploy.SparkHadoopUtil +import java.nio.ByteBuffer +import org.apache.spark.{SparkEnv, SparkException, Logging} + + +/** + * Various utility methods used by Spark. + */ +private[spark] object Utils extends Logging { + + /** Serialize an object using Java serialization */ + def serialize[T](o: T): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(bos) + oos.writeObject(o) + oos.close() + return bos.toByteArray + } + + /** Deserialize an object using Java serialization */ + def deserialize[T](bytes: Array[Byte]): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) + return ois.readObject.asInstanceOf[T] + } + + /** Deserialize an object using Java serialization and the given ClassLoader */ + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + return ois.readObject.asInstanceOf[T] + } + + /** Serialize via nested stream using specific serializer */ + def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { + val osWrapper = ser.serializeStream(new OutputStream { + def write(b: Int) = os.write(b) + + override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + }) + try { + f(osWrapper) + } finally { + osWrapper.close() + } + } + + /** Deserialize via nested stream using specific serializer */ + def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { + val isWrapper = ser.deserializeStream(new InputStream { + def read(): Int = is.read() + + override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) + }) + try { + f(isWrapper) + } finally { + isWrapper.close() + } + } + + /** + * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. 
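A quick sketch of the Java-serialization round trip provided by the serialize/deserialize helpers above:

val payload = Map("a" -> 1, "b" -> 2)
val bytes   = Utils.serialize(payload)
val copy    = Utils.deserialize[Map[String, Int]](bytes)
assert(copy == payload)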
+ */ + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + def isAlpha(c: Char): Boolean = { + (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') + } + + /** Split a string into words at non-alphabetic characters */ + def splitWords(s: String): Seq[String] = { + val buf = new ArrayBuffer[String] + var i = 0 + while (i < s.length) { + var j = i + while (j < s.length && isAlpha(s.charAt(j))) { + j += 1 + } + if (j > i) { + buf += s.substring(i, j) + } + i = j + while (i < s.length && !isAlpha(s.charAt(i))) { + i += 1 + } + } + return buf + } + + private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + }.isDefined + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** Create a temporary directory inside the given parent directory */ + def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { + var attempts = 0 + val maxAttempts = 10 + var dir: File = null + while (dir == null) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException("Failed to create a temp directory (under " + root + ") after " + + maxAttempts + " attempts!") + } + try { + dir = new File(root, "spark-" + UUID.randomUUID.toString) + if (dir.exists() || !dir.mkdirs()) { + dir = null + } + } catch { case e: IOException => ; } + } + + registerShutdownDeleteDir(dir) + + // Add a shutdown hook to delete the temp dir when the JVM exits + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { + override def run() { + // Attempt to delete if some patch which is parent of this is not already registered. + if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) + } + }) + dir + } + + /** Copy all data from an InputStream to an OutputStream */ + def copyStream(in: InputStream, + out: OutputStream, + closeStreams: Boolean = false) + { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + } + } + if (closeStreams) { + in.close() + out.close() + } + } + + /** + * Download a file requested by the executor. Supports fetching the file in a variety of ways, + * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. 
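+ * After the download, archives ending in .tar, .tar.gz or .tgz are unpacked into the target
+ * directory, and the fetched file is made executable.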
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
+ */
+ def fetchFile(url: String, targetDir: File) {
+ val filename = url.split("/").last
+ val tempDir = getLocalDir
+ val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
+ val targetFile = new File(targetDir, filename)
+ val uri = new URI(url)
+ uri.getScheme match {
+ case "http" | "https" | "ftp" =>
+ logInfo("Fetching " + url + " to " + tempFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(tempFile)
+ Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
+ case "file" | null =>
+ // In the case of a local file, copy the local file to the target directory.
+ // Note the difference between uri vs url.
+ val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
+ if (targetFile.exists) {
+ // The target file already exists: fail if its contents differ from the source file,
+ // and skip the copy if they match.
+ if (!Files.equal(sourceFile, targetFile)) {
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
+ } else {
+ // Do nothing if the file contents are the same, i.e. this file has been copied
+ // previously.
+ logInfo(sourceFile.getAbsolutePath + " has been previously copied to " +
+ targetFile.getAbsolutePath)
+ }
+ } else {
+ // The file does not exist in the target directory. Copy it there.
+ logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ Files.copy(sourceFile, targetFile)
+ }
+ case _ =>
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val env = SparkEnv.get
+ val uri = new URI(url)
+ val conf = env.hadoop.newConfiguration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(tempFile)
+ Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
+ }
+ // Unpack the file if it is a .tar, .tar.gz or .tgz archive
+ if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+ } else if (filename.endsWith(".tar")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ }
+ // Make the file executable - That's necessary for scripts
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ }
+
+ /**
+ * Shuffle the elements of a collection into a random order, returning the
+ * result in a new collection. Unlike scala.util.Random.shuffle, this method
+ * uses a local random number generator, avoiding inter-thread contention.
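+ * The input is copied into an array, which is then shuffled in place (see randomizeInPlace below).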
+ */
+ def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = {
+ randomizeInPlace(seq.toArray)
+ }
+
+ /**
+ * Shuffle the elements of an array into a random order, modifying the
+ * original array. Returns the original array.
+ */
+ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+ for (i <- (arr.length - 1) to 1 by -1) {
+ val j = rand.nextInt(i + 1)
+ val tmp = arr(j)
+ arr(j) = arr(i)
+ arr(i) = tmp
+ }
+ arr
+ }
+
+ /**
+ * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
+ * Note, this is typically not used from within core spark.
+ */
+ lazy val localIpAddress: String = findLocalIpAddress()
+ lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
+
+ private def findLocalIpAddress(): String = {
+ val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
+ if (defaultIpOverride != null) {
+ defaultIpOverride
+ } else {
+ val address = InetAddress.getLocalHost
+ if (address.isLoopbackAddress) {
+ // Address resolves to something like 127.0.1.1, which happens on Debian; try to find
+ // a better address using the local network interfaces
+ for (ni <- NetworkInterface.getNetworkInterfaces) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
+ // We've found an address that looks reasonable!
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
+ " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
+ " instead (on interface " + ni.getName + ")")
+ logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
+ return addr.getHostAddress
+ }
+ }
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
+ " a loopback address: " + address.getHostAddress + ", but we couldn't find any" +
+ " external IP address!")
+ logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
+ }
+ address.getHostAddress
+ }
+ }
+
+ private var customHostname: Option[String] = None
+
+ /**
+ * Allow setting a custom host name because when we run on Mesos we need to use the same
+ * hostname it reports to the master.
+ */
+ def setCustomHostname(hostname: String) {
+ // DEBUG code
+ Utils.checkHost(hostname)
+ customHostname = Some(hostname)
+ }
+
+ /**
+ * Get the local machine's hostname.
+ */
+ def localHostName(): String = {
+ customHostname.getOrElse(localIpAddressHostname)
+ }
+
+ def getAddressHostName(address: String): String = {
+ InetAddress.getByName(address).getHostName
+ }
+
+ def localHostPort(): String = {
+ val retval = System.getProperty("spark.hostPort", null)
+ if (retval == null) {
+ logErrorWithStack("spark.hostPort not set but invoking localHostPort")
+ return localHostName()
+ }
+
+ retval
+ }
+
+ def checkHost(host: String, message: String = "") {
+ assert(host.indexOf(':') == -1, message)
+ }
+
+ def checkHostPort(hostPort: String, message: String = "") {
+ assert(hostPort.indexOf(':') != -1, message)
+ }
+
+ // Used by DEBUG code: remove when all testing is done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ }
+
+ // Typically this will be on the order of the number of nodes in the cluster.
+ // If not, we should change it to an LRU cache or something similar.
+ private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
+
+ def parseHostPort(hostPort: String): (String, Int) = {
+ {
+ // Check cache first.
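+ // A plain get on the ConcurrentHashMap is thread-safe and avoids re-parsing host:port strings seen before.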
+ var cached = hostPortParseResults.get(hostPort) + if (cached != null) return cached + } + + val indx: Int = hostPort.lastIndexOf(':') + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... + // but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + if (-1 == indx) { + val retval = (hostPort, 0) + hostPortParseResults.put(hostPort, retval) + return retval + } + + val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + hostPortParseResults.get(hostPort) + } + + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() + + /** + * Wrapper over newCachedThreadPool. + */ + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + + /** + * Return the string to tell how long has passed in seconds. The passing parameter should be in + * millisecond. + */ + def getUsedTimeMs(startTimeMs: Long): String = { + return " " + (System.currentTimeMillis - startTimeMs) + " ms" + } + + /** + * Wrapper over newFixedThreadPool. + */ + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + + /** + * Delete a file or directory and its contents recursively. + */ + def deleteRecursively(file: File) { + if (file.isDirectory) { + for (child <- file.listFiles()) { + deleteRecursively(child) + } + } + if (!file.delete()) { + throw new IOException("Failed to delete: " + file) + } + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM + * environment variable. + */ + def memoryStringToMb(str: String): Int = { + val lower = str.toLowerCase + if (lower.endsWith("k")) { + (lower.substring(0, lower.length-1).toLong / 1024).toInt + } else if (lower.endsWith("m")) { + lower.substring(0, lower.length-1).toInt + } else if (lower.endsWith("g")) { + lower.substring(0, lower.length-1).toInt * 1024 + } else if (lower.endsWith("t")) { + lower.substring(0, lower.length-1).toInt * 1024 * 1024 + } else {// no suffix, so it's just a number in bytes + (lower.toLong / 1024 / 1024).toInt + } + } + + /** + * Convert a quantity in bytes to a human-readable string such as "4.0 MB". 
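+ * A larger unit is used only once the value reaches at least 2 of that unit, so 1.5 MB is reported as 1536.0 KB.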
+ */ + def bytesToString(size: Long): String = { + val TB = 1L << 40 + val GB = 1L << 30 + val MB = 1L << 20 + val KB = 1L << 10 + + val (value, unit) = { + if (size >= 2*TB) { + (size.asInstanceOf[Double] / TB, "TB") + } else if (size >= 2*GB) { + (size.asInstanceOf[Double] / GB, "GB") + } else if (size >= 2*MB) { + (size.asInstanceOf[Double] / MB, "MB") + } else if (size >= 2*KB) { + (size.asInstanceOf[Double] / KB, "KB") + } else { + (size.asInstanceOf[Double], "B") + } + } + "%.1f %s".formatLocal(Locale.US, value, unit) + } + + /** + * Returns a human-readable string representing a duration such as "35ms" + */ + def msDurationToString(ms: Long): String = { + val second = 1000 + val minute = 60 * second + val hour = 60 * minute + + ms match { + case t if t < second => + "%d ms".format(t) + case t if t < minute => + "%.1f s".format(t.toFloat / second) + case t if t < hour => + "%.1f m".format(t.toFloat / minute) + case t => + "%.2f h".format(t.toFloat / hour) + } + } + + /** + * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". + */ + def megabytesToString(megabytes: Long): String = { + bytesToString(megabytes * 1024L * 1024L) + } + + /** + * Execute a command in the given working directory, throwing an exception if it completes + * with an exit code other than 0. + */ + def execute(command: Seq[String], workingDir: File) { + val process = new ProcessBuilder(command: _*) + .directory(workingDir) + .redirectErrorStream(true) + .start() + new Thread("read stdout for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getInputStream).getLines) { + System.err.println(line) + } + } + }.start() + val exitCode = process.waitFor() + if (exitCode != 0) { + throw new SparkException("Process " + command + " exited with code " + exitCode) + } + } + + /** + * Execute a command in the current working directory, throwing an exception if it completes + * with an exit code other than 0. + */ + def execute(command: Seq[String]) { + execute(command, new File(".")) + } + + /** + * Execute a command and get its output, throwing an exception if it yields a code other than 0. + */ + def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty): String = { + val builder = new ProcessBuilder(command: _*) + .directory(workingDir) + val environment = builder.environment() + for ((key, value) <- extraEnvironment) { + environment.put(key, value) + } + val process = builder.start() + new Thread("read stderr for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + val output = new StringBuffer + val stdoutThread = new Thread("read stdout for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getInputStream).getLines) { + output.append(line) + } + } + } + stdoutThread.start() + val exitCode = process.waitFor() + stdoutThread.join() // Wait for it to finish reading output + if (exitCode != 0) { + throw new SparkException("Process " + command + " exited with code " + exitCode) + } + output.toString + } + + /** + * A regular expression to match classes of the "core" Spark API that we want to skip when + * finding the call site of a method. 
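+ * It matches classes such as org.apache.spark.rdd.RDD or org.apache.spark.api.java.JavaRDD,
+ * but not user classes in other packages.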
+ */ + private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) + + /** + * When called inside a class in the spark package, returns the name of the user code class + * (outside the spark package) that called into Spark, as well as which Spark method they called. + * This is used, for example, to tell users where in their code each RDD got created. + */ + def getCallSiteInfo: CallSiteInfo = { + val trace = Thread.currentThread.getStackTrace().filter( el => + (!el.getMethodName.contains("getStackTrace"))) + + // Keep crawling up the stack trace until we find the first function not inside of the spark + // package. We track the last (shallowest) contiguous Spark method. This might be an RDD + // transformation, a SparkContext function (such as parallelize), or anything else that leads + // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. + var lastSparkMethod = "<unknown>" + var firstUserFile = "<unknown>" + var firstUserLine = 0 + var finished = false + var firstUserClass = "<unknown>" + + for (el <- trace) { + if (!finished) { + if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) { + lastSparkMethod = if (el.getMethodName == "<init>") { + // Spark method is a constructor; get its class name + el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) + } else { + el.getMethodName + } + } + else { + firstUserLine = el.getLineNumber + firstUserFile = el.getFileName + firstUserClass = el.getClassName + finished = true + } + } + } + new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) + } + + def formatSparkCallSite = { + val callSiteInfo = getCallSiteInfo + "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, + callSiteInfo.firstUserLine) + } + + /** Return a string containing part of a file from byte 'start' to 'end'. */ + def offsetBytes(path: String, start: Long, end: Long): String = { + val file = new File(path) + val length = file.length() + val effectiveEnd = math.min(length, end) + val effectiveStart = math.max(0, start) + val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) + val stream = new FileInputStream(file) + + stream.skip(effectiveStart) + stream.read(buff) + stream.close() + Source.fromBytes(buff).mkString + } + + /** + * Clone an object using a Spark serializer. + */ + def clone[T](value: T, serializer: SerializerInstance): T = { + serializer.deserialize[T](serializer.serialize(value)) + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + return false + } + + def isSpace(c: Char): Boolean = { + " \t\r\n".indexOf(c) != -1 + } + + /** + * Split a string of potentially quoted arguments from the command line the way that a shell + * would do it to determine arguments to a command. 
For example, if the string is 'a "b c" d', + * then it would be parsed as three arguments: 'a', 'b c' and 'd'. + */ + def splitCommandString(s: String): Seq[String] = { + val buf = new ArrayBuffer[String] + var inWord = false + var inSingleQuote = false + var inDoubleQuote = false + var curWord = new StringBuilder + def endWord() { + buf += curWord.toString + curWord.clear() + } + var i = 0 + while (i < s.length) { + var nextChar = s.charAt(i) + if (inDoubleQuote) { + if (nextChar == '"') { + inDoubleQuote = false + } else if (nextChar == '\\') { + if (i < s.length - 1) { + // Append the next character directly, because only " and \ may be escaped in + // double quotes after the shell's own expansion + curWord.append(s.charAt(i + 1)) + i += 1 + } + } else { + curWord.append(nextChar) + } + } else if (inSingleQuote) { + if (nextChar == '\'') { + inSingleQuote = false + } else { + curWord.append(nextChar) + } + // Backslashes are not treated specially in single quotes + } else if (nextChar == '"') { + inWord = true + inDoubleQuote = true + } else if (nextChar == '\'') { + inWord = true + inSingleQuote = true + } else if (!isSpace(nextChar)) { + curWord.append(nextChar) + inWord = true + } else if (inWord && isSpace(nextChar)) { + endWord() + inWord = false + } + i += 1 + } + if (inWord || inDoubleQuote || inSingleQuote) { + endWord() + } + return buf + } + + /* Calculates 'x' modulo 'mod', takes to consideration sign of x, + * i.e. if 'x' is negative, than 'x' % 'mod' is negative too + * so function return (x % mod) + mod in that case. + */ + def nonNegativeMod(x: Int, mod: Int): Int = { + val rawMod = x % mod + rawMod + (if (rawMod < 0) mod else 0) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala new file mode 100644 index 0000000000..fe710c58ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +class Vector(val elements: Array[Double]) extends Serializable { + def length = elements.length + + def apply(index: Int) = elements(index) + + def + (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) + other(i)) + } + + def add(other: Vector) = this + other + + def - (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) - other(i)) + } + + def subtract(other: Vector) = this - other + + def dot(other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += this(i) * other(i) + i += 1 + } + return ans + } + + /** + * return (this + plus) dot other, but without creating any intermediate storage + * @param plus + * @param other + * @return + */ + def plusDot(plus: Vector, other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + if (length != plus.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) + plus(i)) * other(i) + i += 1 + } + return ans + } + + def += (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var i = 0 + while (i < length) { + elements(i) += other(i) + i += 1 + } + this + } + + def addInPlace(other: Vector) = this +=other + + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) + + def multiply (d: Double) = this * d + + def / (d: Double): Vector = this * (1 / d) + + def divide (d: Double) = this / d + + def unary_- = this * -1 + + def sum = elements.reduceLeft(_ + _) + + def squaredDist(other: Vector): Double = { + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) - other(i)) * (this(i) - other(i)) + i += 1 + } + return ans + } + + def dist(other: Vector): Double = math.sqrt(squaredDist(other)) + + override def toString = elements.mkString("(", ", ", ")") +} + +object Vector { + def apply(elements: Array[Double]) = new Vector(elements) + + def apply(elements: Double*) = new Vector(elements.toArray) + + def apply(length: Int, initializer: Int => Double): Vector = { + val elements: Array[Double] = Array.tabulate(length)(initializer) + return new Vector(elements) + } + + def zeros(length: Int) = new Vector(new Array[Double](length)) + + def ones(length: Int) = Vector(length, _ => 1) + + class Multiplier(num: Double) { + def * (vec: Vector) = vec * num + } + + implicit def doubleToMultiplier(num: Double) = new Multiplier(num) + + implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { + def addInPlace(t1: Vector, t2: Vector) = t1 + t2 + + def zero(initialValue: Vector) = Vector.zeros(initialValue.length) + } + +} |
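
The Vector class above is a small array-backed numeric vector with operator methods, an implicit scalar Multiplier, and an AccumulatorParam instance for use with Spark accumulators. The snippet below is an illustrative usage sketch only, not part of the commit; the object name VectorExample and the sample values are invented for this example. It shows how the operators and the implicit conversion compose:

    import org.apache.spark.util.Vector
    import org.apache.spark.util.Vector._   // brings doubleToMultiplier into scope for "2.0 * v"

    object VectorExample {
      def main(args: Array[String]) {
        val a = Vector(1.0, 2.0, 3.0)   // Vector.apply(Double*)
        val b = Vector.ones(3)          // (1.0, 1.0, 1.0)

        println(a + b)                  // elementwise sum: (2.0, 3.0, 4.0)
        println(a - b)                  // elementwise difference: (0.0, 1.0, 2.0)
        println(a dot b)                // 1.0 + 2.0 + 3.0 = 6.0
        println(2.0 * a)                // via the implicit Multiplier: (2.0, 4.0, 6.0)
        println(a / 2.0)                // (0.5, 1.0, 1.5)
        println(a.squaredDist(b))       // 0.0 + 1.0 + 4.0 = 5.0
      }
    }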