From 4aa1205202f26663f59347f25a7d1f03c755545d Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 12:37:29 -0600 Subject: adding typesafe repo to streaming resolvers so that akka-zeromq is found --- project/SparkBuild.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b0b6e21681..44c8058e9d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -162,6 +162,9 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", + resolvers ++= Seq( + "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" + ), libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", "com.github.sgroschupf" % "zkclient" % "0.1", -- cgit v1.2.3 From c1c3682c984c83f75352fc22dcadd3e46058cfaf Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 12:40:44 -0600 Subject: adding checkpoint dir to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 155e785b01..6c9ffa5426 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ streaming-tests.log dependency-reduced-pom.xml .ensime .ensime_lucene +checkpoint -- cgit v1.2.3 From 42822cf95de71039988e22d8690ba6a4bd639227 Mon Sep 17 00:00:00 2001 From: seanm Date: Wed, 13 Mar 2013 11:40:42 -0600 Subject: changing streaming resolver for akka --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 44c8058e9d..7e65979a5d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -163,7 +163,7 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", resolvers ++= Seq( - "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" + "Akka Repository" at "http://repo.akka.io/releases/" ), libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", -- cgit v1.2.3 From cfa8e769a86664722f47182fa572179e8beadcb7 Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 17:16:15 -0600 Subject: KafkaInputDStream improvements. Allows more Kafka configurability --- .../scala/spark/streaming/StreamingContext.scala | 22 +++++++++- .../streaming/dstream/KafkaInputDStream.scala | 48 ++++++++++++++-------- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 25c67b279b..4e1732adf5 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -199,7 +199,7 @@ class StreamingContext private ( } /** - * Create an input stream that pulls messages form a Kafka Broker. + * Create an input stream that pulls messages from a Kafka Broker. * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. 
Each partition is consumed @@ -216,7 +216,25 @@ class StreamingContext private ( initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel) + val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); + kafkaStream[T](kafkaParams, topics, initialOffsets, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * @param storageLevel Storage level to use for storing the received objects + */ + def kafkaStream[T: ClassManifest]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ): DStream[T] = { + val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, initialOffsets, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index dc7139cc27..f769fc1cc3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -12,6 +12,8 @@ import kafka.message.{Message, MessageSet, MessageAndMetadata} import kafka.serializer.StringDecoder import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ +import kafka.utils.ZKStringSerializer +import org.I0Itec.zkclient._ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ @@ -23,8 +25,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part /** * Input stream that pulls messages from a Kafka Broker. * - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param initialOffsets Optional initial offsets for each of the partitions to consume. 
@@ -34,8 +35,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part private[streaming] class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, - zkQuorum: String, - groupId: String, + kafkaParams: Map[String, String], topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel @@ -43,19 +43,16 @@ class KafkaInputDStream[T: ClassManifest]( def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel) + new KafkaReceiver(kafkaParams, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(zkQuorum: String, groupId: String, +class KafkaReceiver(kafkaParams: Map[String, String], topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel) extends NetworkReceiver[Any] { - // Timeout for establishing a connection to Zookeper in ms. - val ZK_TIMEOUT = 10000 - // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka @@ -72,20 +69,24 @@ class KafkaReceiver(zkQuorum: String, groupId: String, // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) logInfo("Initial offsets: " + initialOffsets.toString) - // Zookeper connection properties + // Kafka connection properties val props = new Properties() - props.put("zk.connect", zkQuorum) - props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) - props.put("groupid", groupId) + kafkaParams.foreach(param => props.put(param._1, param._2)) // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zkQuorum) + logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zkQuorum) + logInfo("Connected to " + kafkaParams("zk.connect")) + + // When autooffset.reset is 'smallest', it is our responsibility to try and whack the + // consumer group zk node. + if (kafkaParams.get("autooffset.reset").exists(_ == "smallest")) { + tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) + } // If specified, set the topic offset setOffsets(initialOffsets) @@ -97,7 +98,6 @@ class KafkaReceiver(zkQuorum: String, groupId: String, topicMessageStreams.values.foreach { streams => streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } } - } // Overwrites the offets in Zookeper. @@ -122,4 +122,18 @@ class KafkaReceiver(zkQuorum: String, groupId: String, } } } + + // Handles cleanup of consumer group znode. 
Lifted with love from Kafka's + // ConsumerConsole.scala tryCleanupZookeeper() + private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { + try { + val dir = "/consumers/" + groupId + logInfo("Cleaning up temporary zookeeper data under " + dir + ".") + val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) + zk.deleteRecursive(dir) + zk.close() + } catch { + case _ => // swallow + } + } } -- cgit v1.2.3 From d06928321194b11e082986cd2bb2737d9bc3b698 Mon Sep 17 00:00:00 2001 From: seanm Date: Thu, 14 Mar 2013 23:25:35 -0600 Subject: fixing memory leak in kafka MessageHandler --- .../src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index f769fc1cc3..d674b6ee87 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -114,11 +114,8 @@ class KafkaReceiver(kafkaParams: Map[String, String], private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => + for (msgAndMetadata <- stream) { blockGenerator += msgAndMetadata.message - // Keep on handling messages - - true } } } -- cgit v1.2.3 From 33fa1e7e4aca4d9e0edf65d2b768b569305fd044 Mon Sep 17 00:00:00 2001 From: seanm Date: Thu, 14 Mar 2013 23:32:52 -0600 Subject: removing dependency on ZookeeperConsumerConnector + purging last relic of kafka reliability that never solidified (ie- setOffsets) --- .../scala/spark/streaming/StreamingContext.scala | 9 ++----- .../streaming/api/java/JavaStreamingContext.scala | 28 ---------------------- .../streaming/dstream/KafkaInputDStream.scala | 28 ++++------------------ 3 files changed, 6 insertions(+), 59 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4e1732adf5..bb7f216ca7 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -204,8 +204,6 @@ class StreamingContext private ( * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ @@ -213,11 +211,10 @@ class StreamingContext private ( zkQuorum: String, groupId: String, topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); - kafkaStream[T](kafkaParams, topics, initialOffsets, storageLevel) + kafkaStream[T](kafkaParams, topics, storageLevel) } /** @@ -225,16 +222,14 @@ class StreamingContext private ( * @param kafkaParams Map of kafka configuration paramaters. 
See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. * @param storageLevel Storage level to use for storing the received objects */ def kafkaStream[T: ClassManifest]( kafkaParams: Map[String, String], topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, initialOffsets, storageLevel) + val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index f3b40b5b88..2373f4824a 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -84,39 +84,12 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. - */ - def kafkaStream[T]( - zkQuorum: String, - groupId: String, - topics: JMap[String, JInt], - initialOffsets: JMap[KafkaPartitionKey, JLong]) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T]( - zkQuorum, - groupId, - Map(topics.mapValues(_.intValue()).toSeq: _*), - Map(initialOffsets.mapValues(_.longValue()).toSeq: _*)) - } - - /** - * Create an input stream that pulls messages form a Kafka Broker. - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. 
Defaults to memory-only */ def kafkaStream[T]( zkQuorum: String, groupId: String, topics: JMap[String, JInt], - initialOffsets: JMap[KafkaPartitionKey, JLong], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = @@ -125,7 +98,6 @@ class JavaStreamingContext(val ssc: StreamingContext) { zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), - Map(initialOffsets.mapValues(_.longValue()).toSeq: _*), storageLevel) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index d674b6ee87..c6da1a7f70 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -19,17 +19,12 @@ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ -// Key for a specific Kafka Partition: (broker, topic, group, part) -case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) - /** * Input stream that pulls messages from a Kafka Broker. * * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. */ private[streaming] @@ -37,26 +32,25 @@ class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(kafkaParams, topics, initialOffsets, storageLevel) + new KafkaReceiver(kafkaParams, topics, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] class KafkaReceiver(kafkaParams: Map[String, String], - topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + topics: Map[String, Int], storageLevel: StorageLevel) extends NetworkReceiver[Any] { // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka - var consumerConnector : ZookeeperConsumerConnector = null + var consumerConnector : ConsumerConnector = null def onStop() { blockGenerator.stop() @@ -70,7 +64,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) - logInfo("Initial offsets: " + initialOffsets.toString) // Kafka connection properties val props = new Properties() @@ -79,7 +72,7 @@ class KafkaReceiver(kafkaParams: Map[String, String], // Create the connection to the cluster logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + kafkaParams("zk.connect")) // When autooffset.reset is 'smallest', it is our responsibility to try and whack the @@ -88,9 +81,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], 
tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) } - // If specified, set the topic offset - setOffsets(initialOffsets) - // Create Threads for each Topic/Message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) @@ -100,16 +90,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Overwrites the offets in Zookeper. - private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) { - offsets.foreach { case(key, offset) => - val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) - val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, - topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) - } - } - // Handles Kafka Messages private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { -- cgit v1.2.3 From d61978d0abad30a148680c8a63df33e40e469525 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 15 Mar 2013 23:36:52 -0600 Subject: keeping JavaStreamingContext in sync with StreamingContext + adding comments for better clarity --- .../main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 7 +++---- .../src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 6 ++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 2373f4824a..7a8864614c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -80,6 +80,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -87,16 +88,14 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( - zkQuorum: String, - groupId: String, + kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - zkQuorum, - groupId, + kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index c6da1a7f70..85693808d1 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -100,8 +100,10 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Handles cleanup of consumer group znode. Lifted with love from Kafka's - // ConsumerConsole.scala tryCleanupZookeeper() + // Delete consumer group from zookeeper. This effectivly resets the group so we can consume from the beginning again. + // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' + // ConsoleConsumer. 
See code related to 'autooffset.reset' when it is set to 'smallest': + // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { try { val dir = "/consumers/" + groupId -- cgit v1.2.3 From 329ef34c2e04d28c2ad150cf6674d6e86d7511ce Mon Sep 17 00:00:00 2001 From: seanm Date: Tue, 26 Mar 2013 23:56:15 -0600 Subject: fixing autooffset.reset behavior when set to 'largest' --- .../main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 85693808d1..17a5be3420 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -75,9 +75,9 @@ class KafkaReceiver(kafkaParams: Map[String, String], consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + kafkaParams("zk.connect")) - // When autooffset.reset is 'smallest', it is our responsibility to try and whack the + // When autooffset.reset is defined, it is our responsibility to try and whack the // consumer group zk node. - if (kafkaParams.get("autooffset.reset").exists(_ == "smallest")) { + if (kafkaParams.contains("autooffset.reset")) { tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) } @@ -100,9 +100,11 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Delete consumer group from zookeeper. This effectivly resets the group so we can consume from the beginning again. + // It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because + // Kafka 0.7.2 only honors this param when the group is not in zookeeper. + // // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' - // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest': + // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest'/'largest': // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { try { -- cgit v1.2.3 From b42d68c8ce9f63513969297b65f4b5a2b06e6078 Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 15 Apr 2013 12:54:55 -0600 Subject: fixing Spark Streaming count() so that 0 will be emitted when there is nothing to count --- streaming/src/main/scala/spark/streaming/DStream.scala | 5 ++++- streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e1be5ef51c..e3a9247924 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,7 +441,10 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. 
*/ - def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) + def count(): DStream[Long] = { + val zero = new ConstantInputDStream(context, context.sparkContext.makeRDD(Seq((null, 0L)), 1)) + this.map(_ => (null, 1L)).union(zero).reduceByKey(_ + _).map(_._2) + } /** * Return a new DStream in which each RDD contains the counts of each distinct value in diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index 8fce91853c..168e1b7a55 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -90,9 +90,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("count") { testOperation( - Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4), + Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4), (s: DStream[Int]) => s.count(), - Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L)) + Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L)) ) } -- cgit v1.2.3 From ab0f834dbb509d323577572691293b74368a9d86 Mon Sep 17 00:00:00 2001 From: seanm Date: Tue, 16 Apr 2013 11:57:05 -0600 Subject: adding spark.streaming.blockInterval property --- docs/configuration.md | 7 +++++++ .../main/scala/spark/streaming/dstream/NetworkInputDStream.scala | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 04eb6daaa5..55f1962b18 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -253,6 +253,13 @@ Apart from these, the following properties are also available, and may be useful applications). Note that any RDD that persists in memory for more than this duration will be cleared as well. + + spark.streaming.blockInterval + 200 + + Duration (milliseconds) of how long to batch new objects coming from network receivers. 
+ + diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 7385474963..26805e9621 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -198,7 +198,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log case class Block(id: String, iterator: Iterator[T], metadata: Any = null) val clock = new SystemClock() - val blockInterval = 200L + val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) -- cgit v1.2.3 From 7e56e99573b4cf161293e648aeb159375c9c0fcb Mon Sep 17 00:00:00 2001 From: seanm Date: Sun, 24 Mar 2013 13:40:19 -0600 Subject: Surfacing decoders on KafkaInputDStream --- .../spark/streaming/examples/KafkaWordCount.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 11 ++++---- .../streaming/api/java/JavaStreamingContext.scala | 33 ++++++++++++++++------ .../streaming/dstream/KafkaInputDStream.scala | 17 ++++++----- .../test/java/spark/streaming/JavaAPISuite.java | 6 ++-- 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 9b135a5c54..e0c3555f21 100644 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -37,7 +37,7 @@ object KafkaWordCount { ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap) + val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) wordCounts.print() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index bb7f216ca7..2c6326943d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import java.util.UUID import twitter4j.Status + /** * A StreamingContext is the main entry point for Spark Streaming functionality. 
Besides the basic * information (such as, cluster URL and job name) to internally create a SparkContext, it provides @@ -207,14 +208,14 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ - def kafkaStream[T: ClassManifest]( + def kafkaStream( zkQuorum: String, groupId: String, topics: Map[String, Int], storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 - ): DStream[T] = { + ): DStream[String] = { val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); - kafkaStream[T](kafkaParams, topics, storageLevel) + kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) } /** @@ -224,12 +225,12 @@ class StreamingContext private ( * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ - def kafkaStream[T: ClassManifest]( + def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel) + val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 7a8864614c..13427873ff 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -68,33 +68,50 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. */ - def kafkaStream[T]( + def kafkaStream( zkQuorum: String, groupId: String, topics: JMap[String, JInt]) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2) } /** * Create an input stream that pulls messages form a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * @param storageLevel RDD storage level. Defaults to memory-only + * in its own thread. + */ + def kafkaStream( + zkQuorum: String, + groupId: String, + topics: JMap[String, JInt], + storageLevel: StorageLevel) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + } + + /** + * Create an input stream that pulls messages form a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. 
See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only */ - def kafkaStream[T]( + def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T]( + ssc.kafkaStream[T, D]( kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 17a5be3420..7bd53fb6dd 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -9,7 +9,7 @@ import java.util.concurrent.Executors import kafka.consumer._ import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.StringDecoder +import kafka.serializer.Decoder import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ import kafka.utils.ZKStringSerializer @@ -28,7 +28,7 @@ import scala.collection.JavaConversions._ * @param storageLevel RDD storage level. */ private[streaming] -class KafkaInputDStream[T: ClassManifest]( +class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -37,15 +37,17 @@ class KafkaInputDStream[T: ClassManifest]( def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(kafkaParams, topics, storageLevel) + new KafkaReceiver[T, D](kafkaParams, topics, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(kafkaParams: Map[String, String], +class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest]( + kafkaParams: Map[String, String], topics: Map[String, Int], - storageLevel: StorageLevel) extends NetworkReceiver[Any] { + storageLevel: StorageLevel + ) extends NetworkReceiver[Any] { // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) @@ -82,7 +84,8 @@ class KafkaReceiver(kafkaParams: Map[String, String], } // Create Threads for each Topic/Message Stream we are listening - val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]] + val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder) // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -91,7 +94,7 @@ class KafkaReceiver(kafkaParams: Map[String, String], } // Handles Kafka Messages - private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable { def run() { logInfo("Starting MessageHandler.") for (msgAndMetadata <- stream) { diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 3bed500f73..61e4c0a207 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -23,7 +23,6 @@ import 
spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; import spark.streaming.JavaCheckpointTestUtils; -import spark.streaming.dstream.KafkaPartitionKey; import spark.streaming.InputStreamsSuite; import java.io.*; @@ -1203,10 +1202,9 @@ public class JavaAPISuite implements Serializable { @Test public void testKafkaStream() { HashMap topics = Maps.newHashMap(); - HashMap offsets = Maps.newHashMap(); JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets, + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics); + JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); } -- cgit v1.2.3 From ee6f6aa6cd028e6a3938dcd5334661c27f493bc6 Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Thu, 9 May 2013 18:33:38 -0500 Subject: Add hBase example --- .../src/main/scala/spark/examples/HBaseTest.scala | 35 ++++++++++++++++++++++ project/SparkBuild.scala | 6 +++- 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/spark/examples/HBaseTest.scala diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala new file mode 100644 index 0000000000..6e910154d4 --- /dev/null +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -0,0 +1,35 @@ +package spark.examples + +import spark._ +import spark.rdd.NewHadoopRDD +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.client.HBaseAdmin +import org.apache.hadoop.hbase.mapreduce.TableInputFormat + +object HBaseTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HBaseTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val conf = HBaseConfiguration.create() + + // Other options for configuring scan behavior are available. 
More information available at + // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html + conf.set(TableInputFormat.INPUT_TABLE, args(1)) + + // Initialize hBase table if necessary + val admin = new HBaseAdmin(conf) + if(!admin.isTableAvailable(args(1))) { + val tableDesc = new HTableDescriptor(args(1)) + admin.createTable(tableDesc) + } + + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result]) + + hBaseRDD.count() + + System.exit(0) + } +} \ No newline at end of file diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 190d723435..57fe04ea2d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -200,7 +200,11 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") + resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"), + libraryDependencies ++= Seq( + "com.twitter" % "algebird-core_2.9.2" % "0.1.11", + "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty) + ) ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From d761e7359deb7ca864d33b8f2e4380b57448630b Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 12:05:10 -0600 Subject: adding kafkaStream API tests --- streaming/src/test/java/spark/streaming/JavaAPISuite.java | 4 ++-- .../src/test/scala/spark/streaming/InputStreamsSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 61e4c0a207..350d0888a3 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -4,6 +4,7 @@ import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.io.Files; +import kafka.serializer.StringDecoder; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.junit.After; import org.junit.Assert; @@ -1203,8 +1204,7 @@ public class JavaAPISuite implements Serializable { public void testKafkaStream() { HashMap topics = Maps.newHashMap(); JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 1024d3ac97..595c766a21 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -240,6 +240,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i) === expectedOutput(i)) } } + + test("kafka input stream") { + val ssc = new StreamingContext(master, framework, batchDuration) + val topics = Map("my-topic" -> 1) + val test1 = ssc.kafkaStream("localhost:12345", "group", topics) + val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK) + + // 
Test specifying decoder + val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group") + val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) + } } -- cgit v1.2.3 From b95c1bdbbaeea86152e24b394a03bbbad95989d5 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 12:47:24 -0600 Subject: count() now uses a transform instead of ConstantInputDStream --- streaming/src/main/scala/spark/streaming/DStream.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e3a9247924..e125310861 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,10 +441,7 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = { - val zero = new ConstantInputDStream(context, context.sparkContext.makeRDD(Seq((null, 0L)), 1)) - this.map(_ => (null, 1L)).union(zero).reduceByKey(_ + _).map(_._2) - } + def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2) /** * Return a new DStream in which each RDD contains the counts of each distinct value in -- cgit v1.2.3 From 3632980b1b61dbb9ab9a3ab3d92fb415cb7173b9 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 15:54:26 -0600 Subject: fixing indentation --- .../src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 13427873ff..4ad2bdf8a8 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -105,7 +105,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( - kafkaParams: JMap[String, String], + kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { -- cgit v1.2.3 From f25282def5826fab6caabff28c82c57a7f3fdcb8 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 17:34:28 -0600 Subject: fixing kafkaStream Java API and adding test --- .../scala/spark/streaming/api/java/JavaStreamingContext.scala | 10 +++++++--- streaming/src/test/java/spark/streaming/JavaAPISuite.java | 6 ++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 4ad2bdf8a8..b35d9032f1 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -99,18 +99,22 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. + * @param typeClass Type of RDD + * @param decoderClass Type of kafka decoder * @param kafkaParams Map of kafka configuration paramaters. 
See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only */ - def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( + def kafkaStream[T, D <: kafka.serializer.Decoder[_]]( + typeClass: Class[T], + decoderClass: Class[D], kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] ssc.kafkaStream[T, D]( kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 350d0888a3..e5fdbe1b7a 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -1206,6 +1206,12 @@ public class JavaAPISuite implements Serializable { JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); + + HashMap kafkaParams = Maps.newHashMap(); + kafkaParams.put("zk.connect","localhost:12345"); + kafkaParams.put("groupid","consumer-group"); + JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics, + StorageLevel.MEMORY_AND_DISK()); } @Test -- cgit v1.2.3 From 3217d486f7fdd590250f2efee567e4779e130d34 Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Mon, 20 May 2013 19:41:38 -0500 Subject: Add hBase dependency to examples POM --- examples/pom.xml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/pom.xml b/examples/pom.xml index c42d2bcdb9..0fbb5a3d5d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -67,6 +67,11 @@ hadoop-core provided + + org.apache.hbase + hbase + 0.94.6 + @@ -105,6 +110,11 @@ hadoop-client provided + + org.apache.hbase + hbase + 0.94.6 + -- cgit v1.2.3 From a674d67c0aebb940e3b816e2307206115baec175 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 28 May 2013 16:24:05 -0500 Subject: Fix start-slave not passing instance number to spark-daemon. 
--- bin/start-slave.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/start-slave.sh b/bin/start-slave.sh index 26b5b9d462..dfcbc6981b 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -12,4 +12,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@" +"$bin"/spark-daemon.sh start spark.deploy.worker.Worker 1 "$@" -- cgit v1.2.3 From ecceb101d3019ef511c42a8a8a3bb0e46520ffef Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Thu, 30 May 2013 10:43:01 +0800 Subject: implement FIFO and fair scheduler for spark local mode --- .../spark/scheduler/cluster/ClusterScheduler.scala | 2 +- .../scheduler/cluster/ClusterTaskSetManager.scala | 734 +++++++++++++++++++++ .../spark/scheduler/cluster/TaskSetManager.scala | 733 +------------------- .../spark/scheduler/local/LocalScheduler.scala | 386 +++++++++-- .../spark/scheduler/ClusterSchedulerSuite.scala | 2 +- 5 files changed, 1057 insertions(+), 800 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 053d4b8e4a..3a0c29b27f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet) + val manager = new ClusterTaskSetManager(this, taskSet) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala new file mode 100644 index 0000000000..ec4041ab86 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -0,0 +1,734 @@ +package spark.scheduler.cluster + +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import spark._ +import spark.scheduler._ +import spark.TaskState.TaskState +import java.nio.ByteBuffer + +private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + // process local is expected to be used ONLY within tasksetmanager for now. + val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + // Must not be the constraint. + assert (constraint != TaskLocality.PROCESS_LOCAL) + + constraint match { + case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + val retval = TaskLocality.withName(str) + // Must not specify PROCESS_LOCAL ! 
+ assert (retval != TaskLocality.PROCESS_LOCAL) + + retval + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); + // default to preserve earlier behavior + NODE_LOCAL + } + } + } +} + +/** + * Schedules the tasks within a single TaskSet in the ClusterScheduler. + */ +private[spark] class ClusterTaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet) + extends TaskSetManager + with Logging { + + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = 4 + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. + val ser = SparkEnv.get.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksFinished = 0 + + var weight = 1 + var minShare = 0 + var runningTasks = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent:Schedulable = null + + // Last time when we launched a preferred task (for delay scheduling) + var lastPreferredLaunchTime = System.currentTimeMillis + + // List of pending tasks for each node (process local to container). These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. + // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List containing pending tasks with no locality preferences + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // List containing all pending tasks (also used as a stack, as above) + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the job fail? 
+ var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + // Map of recent exceptions (identified by string representation and + // top stack frame) to duplicate count (how many times the same + // exception has appeared) and time the full exception was + // printed. This should ideally be an LRU map that can drop old + // exceptions automatically. + val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker generation and set it on all tasks + val generation = sched.mapOutputTracker.getGeneration + logDebug("Generation for " + taskSet.id + ": " + generation) + for (t <- tasks) { + t.generation = generation + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Note that it follows the hierarchy. + // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and + // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, + taskLocality: TaskLocality.TaskLocality): HashSet[String] = { + + if (TaskLocality.PROCESS_LOCAL == taskLocality) { + // straight forward comparison ! Special case it. + val retval = new HashSet[String]() + scheduler.synchronized { + for (location <- _taskPreferredLocations) { + if (scheduler.isExecutorAliveOnHostPort(location)) { + retval += location + } + } + } + + return retval + } + + val taskPreferredLocations = + if (TaskLocality.NODE_LOCAL == taskLocality) { + _taskPreferredLocations + } else { + assert (TaskLocality.RACK_LOCAL == taskLocality) + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new HashSet[String] + scheduler.synchronized { + for (prefLocation <- taskPreferredLocations) { + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) + if (aliveLocationsOpt.isDefined) { + retval ++= aliveLocationsOpt.get + } + } + } + + retval + } + + // Add a task to all the pending-task lists that it should be on. + private def addPendingTask(index: Int) { + // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate + // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. 
+ val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + + if (rackLocalLocations.size == 0) { + // Current impl ensures this. + assert (processLocalLocations.size == 0) + assert (hostLocalLocations.size == 0) + pendingTasksWithNoPrefs += index + } else { + + // process local locality + for (hostPort <- processLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) + hostPortList += index + } + + // host locality (includes process local) + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val host = Utils.parseHostPort(hostPort)._1 + val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + hostList += index + } + + // rack locality (includes process local and host local) + for (rackLocalHostPort <- rackLocalLocations) { + // DEBUG Code + Utils.checkHostPort(rackLocalHostPort) + + val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 + val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) + list += index + } + } + + allPendingTasks += index + } + + // Return the pending tasks list for a given host port (process local), or an empty list if + // there is no map entry for that host + private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { + // DEBUG Code + Utils.checkHostPort(hostPort) + pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) + } + + // Return the pending tasks list for a given host, or an empty list if + // there is no map entry for that host + private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Return the pending tasks (rack level) list for a given host, or an empty list if + // there is no map entry for that host + private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Number of pending tasks for a given host Port (which would be process local) + def numPendingTasksForHostPort(hostPort: String): Int = { + getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending tasks for a given host (which would be data local) + def numPendingTasksForHost(hostPort: String): Int = { + getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending rack local tasks for a given host + def numRackLocalPendingTasksForHost(hostPort: String): Int = { + getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + + // Dequeue a pending task from the given list and return its index. + // Return None if the list is empty. + // This method also cleans up any tasks in the list that have already + // been launched, since we want that to happen lazily. 
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (copiesRunning(index) == 0 && !finished(index)) { + return Some(index) + } + } + return None + } + + // Return a speculative task for a given host if any are available. The task should not have an + // attempt running on this host, in case the host is slow. In addition, if locality is set, the + // task must have a preference for this host/rack/no preferred locations at all. + private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + + assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) + speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set + + if (speculatableTasks.size > 0) { + val localTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + val attemptLocs = taskAttempts(index).map(_.hostPort) + (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) + } + + if (localTask != None) { + speculatableTasks -= localTask.get + return localTask + } + + // check for rack locality + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val attemptLocs = taskAttempts(index).map(_.hostPort) + locations.contains(hostPort) && !attemptLocs.contains(hostPort) + } + + if (rackTask != None) { + speculatableTasks -= rackTask.get + return rackTask + } + } + + // Any task ... + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + // Check for attemptLocs also ? + val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) + if (nonLocalTask != None) { + speculatableTasks -= nonLocalTask.get + return nonLocalTask + } + } + } + return None + } + + // Dequeue a pending task for a given node and return its index. + // If localOnly is set to false, allow non-local tasks as well. + private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) + if (processLocalTask != None) { + return processLocalTask + } + + val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) + if (localTask != None) { + return localTask + } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) + if (rackLocalTask != None) { + return rackLocalTask + } + } + + // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. + // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). 
+ val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) + if (noPrefTask != None) { + return noPrefTask + } + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + val nonLocalTask = findTaskFromList(allPendingTasks) + if (nonLocalTask != None) { + return nonLocalTask + } + } + + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(hostPort, locality) + } + + private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = { + Utils.checkHostPort(hostPort) + + val locs = task.preferredLocations + + locs.contains(hostPort) + } + + private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = { + val locs = task.preferredLocations + + // If no preference, consider it as host local + if (locs.isEmpty) return true + + val host = Utils.parseHostPort(hostPort)._1 + locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined + } + + // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). + // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) + } + + // Respond to an offer of a single slave from the scheduler by finding a task + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... 
+ val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY + } + + findTask(hostPort, locality) match { + case Some(index) => { + // Found a task; do some bookkeeping and return a Mesos task for it + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + val taskLocality = + if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else + if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else + TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) + // Do various bookkeeping + copiesRunning(index) += 1 + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + if (TaskLocality.NODE_LOCAL == taskLocality) { + lastPreferredLaunchTime = time + } + // Serialize and return the task + val startTime = System.currentTimeMillis + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = System.currentTimeMillis - startTime + increaseRunningTasks(1) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) + } + case _ => + } + } + return None + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskFinished(tid, state, serializedData) + case TaskState.LOST => + taskLost(tid, state, serializedData) + case TaskState.FAILED => + taskLost(tid, state, serializedData) + case TaskState.KILLED => + taskLost(tid, state, serializedData) + case _ => + } + } + + def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + decreaseRunningTasks(1) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. 
+ return + } + val index = info.index + info.markFailed() + decreaseRunningTasks(1) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. + if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + return + + case ef: ExceptionFailure => + val key = ef.description + val now = System.currentTimeMillis + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.listener.taskSetFailed(taskSet, message) + decreaseRunningTasks(runningTasks) + sched.taskSetFinished(this) + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable:Schedulable) { + //nothing + } + + override def removeSchedulable(schedulable:Schedulable) { + //nothing + } + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + override def executorLost(execId: String, hostPort: String) { + 
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + + // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to + // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. + // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no process local node for the task) + for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { + // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (finished(index)) { + finished(index) = false + copiesRunning(index) -= 1 + tasksFinished -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + taskLost(tid, TaskState.KILLED, null) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the ClusterScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. + if (numTasks == 1 || tasksFinished == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksFinished >= minFinishedForSpeculation) { + val time = System.currentTimeMillis() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. 
+ logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.hostPort, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + override def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 1c403ef323..2b5a74d4e5 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,734 +1,17 @@ package spark.scheduler.cluster -import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} - import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import spark._ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer - -private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { - - // process local is expected to be used ONLY within tasksetmanager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value - - type TaskLocality = Value - - def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { - - // Must not be the constraint. - assert (constraint != TaskLocality.PROCESS_LOCAL) - - constraint match { - case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL - case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL - // For anything else, allow - case _ => true - } - } - - def parse(str: String): TaskLocality = { - // better way to do this ? - try { - val retval = TaskLocality.withName(str) - // Must not specify PROCESS_LOCAL ! - assert (retval != TaskLocality.PROCESS_LOCAL) - - retval - } catch { - case nEx: NoSuchElementException => { - logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); - // default to preserve earlier behavior - NODE_LOCAL - } - } - } -} - /** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet) - extends Schedulable - with Logging { - - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = 4 - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. 
- val ser = SparkEnv.get.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksFinished = 0 - - var weight = 1 - var minShare = 0 - var runningTasks = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent:Schedulable = null - - // Last time when we launched a preferred task (for delay scheduling) - var lastPreferredLaunchTime = System.currentTimeMillis - - // List of pending tasks for each node (process local to container). These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] - - // List of pending tasks for each node. - // Essentially, similar to pendingTasksForHostPort, except at host level - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List of pending tasks for each node based on rack locality. - // Essentially, similar to pendingTasksForHost, except at rack level - private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List containing pending tasks with no locality preferences - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // List containing all pending tasks (also used as a stack, as above) - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the job fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - // Map of recent exceptions (identified by string representation and - // top stack frame) to duplicate count (how many times the same - // exception has appeared) and time the full exception was - // printed. This should ideally be an LRU map that can drop old - // exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker generation and set it on all tasks - val generation = sched.mapOutputTracker.getGeneration - logDebug("Generation for " + taskSet.id + ": " + generation) - for (t <- tasks) { - t.generation = generation - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Note that it follows the hierarchy. 
- // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and - // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL - private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, - taskLocality: TaskLocality.TaskLocality): HashSet[String] = { - - if (TaskLocality.PROCESS_LOCAL == taskLocality) { - // straight forward comparison ! Special case it. - val retval = new HashSet[String]() - scheduler.synchronized { - for (location <- _taskPreferredLocations) { - if (scheduler.isExecutorAliveOnHostPort(location)) { - retval += location - } - } - } - - return retval - } - - val taskPreferredLocations = - if (TaskLocality.NODE_LOCAL == taskLocality) { - _taskPreferredLocations - } else { - assert (TaskLocality.RACK_LOCAL == taskLocality) - // Expand set to include all 'seen' rack local hosts. - // This works since container allocation/management happens within master - so any rack locality information is updated in msater. - // Best case effort, and maybe sort of kludge for now ... rework it later ? - val hosts = new HashSet[String] - _taskPreferredLocations.foreach(h => { - val rackOpt = scheduler.getRackForHost(h) - if (rackOpt.isDefined) { - val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) - if (hostsOpt.isDefined) { - hosts ++= hostsOpt.get - } - } - - // Ensure that irrespective of what scheduler says, host is always added ! - hosts += h - }) - - hosts - } - - val retval = new HashSet[String] - scheduler.synchronized { - for (prefLocation <- taskPreferredLocations) { - val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) - if (aliveLocationsOpt.isDefined) { - retval ++= aliveLocationsOpt.get - } - } - } - - retval - } - - // Add a task to all the pending-task lists that it should be on. - private def addPendingTask(index: Int) { - // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate - // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. - val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) - val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - - if (rackLocalLocations.size == 0) { - // Current impl ensures this. 
- assert (processLocalLocations.size == 0) - assert (hostLocalLocations.size == 0) - pendingTasksWithNoPrefs += index - } else { - - // process local locality - for (hostPort <- processLocalLocations) { - // DEBUG Code - Utils.checkHostPort(hostPort) - - val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) - hostPortList += index - } - - // host locality (includes process local) - for (hostPort <- hostLocalLocations) { - // DEBUG Code - Utils.checkHostPort(hostPort) - - val host = Utils.parseHostPort(hostPort)._1 - val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) - hostList += index - } - - // rack locality (includes process local and host local) - for (rackLocalHostPort <- rackLocalLocations) { - // DEBUG Code - Utils.checkHostPort(rackLocalHostPort) - - val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 - val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) - list += index - } - } - - allPendingTasks += index - } - - // Return the pending tasks list for a given host port (process local), or an empty list if - // there is no map entry for that host - private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { - // DEBUG Code - Utils.checkHostPort(hostPort) - pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) - } - - // Return the pending tasks list for a given host, or an empty list if - // there is no map entry for that host - private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { - val host = Utils.parseHostPort(hostPort)._1 - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Return the pending tasks (rack level) list for a given host, or an empty list if - // there is no map entry for that host - private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { - val host = Utils.parseHostPort(hostPort)._1 - pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Number of pending tasks for a given host Port (which would be process local) - def numPendingTasksForHostPort(hostPort: String): Int = { - getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - // Number of pending tasks for a given host (which would be data local) - def numPendingTasksForHost(hostPort: String): Int = { - getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - // Number of pending rack local tasks for a given host - def numRackLocalPendingTasksForHost(hostPort: String): Int = { - getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - - // Dequeue a pending task from the given list and return its index. - // Return None if the list is empty. - // This method also cleans up any tasks in the list that have already - // been launched, since we want that to happen lazily. - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { - return Some(index) - } - } - return None - } - - // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if locality is set, the - // task must have a preference for this host/rack/no preferred locations at all. 
- private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - - assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - - if (speculatableTasks.size > 0) { - val localTask = speculatableTasks.find { - index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - val attemptLocs = taskAttempts(index).map(_.hostPort) - (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) - } - - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - - // check for rack locality - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - val rackTask = speculatableTasks.find { - index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val attemptLocs = taskAttempts(index).map(_.hostPort) - locations.contains(hostPort) && !attemptLocs.contains(hostPort) - } - - if (rackTask != None) { - speculatableTasks -= rackTask.get - return rackTask - } - } - - // Any task ... - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - // Check for attemptLocs also ? - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask - } - } - } - return None - } - - // Dequeue a pending task for a given node and return its index. - // If localOnly is set to false, allow non-local tasks as well. - private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) - if (processLocalTask != None) { - return processLocalTask - } - - val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) - if (localTask != None) { - return localTask - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) - if (rackLocalTask != None) { - return rackLocalTask - } - } - - // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. - // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). - val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) - if (noPrefTask != None) { - return noPrefTask - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - val nonLocalTask = findTaskFromList(allPendingTasks) - if (nonLocalTask != None) { - return nonLocalTask - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(hostPort, locality) - } - - private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = { - Utils.checkHostPort(hostPort) - - val locs = task.preferredLocations - - locs.contains(hostPort) - } - - private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = { - val locs = task.preferredLocations - - // If no preference, consider it as host local - if (locs.isEmpty) return true - - val host = Utils.parseHostPort(hostPort)._1 - locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined - } - - // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). 
- // This is true if either the task has preferred locations and this host is one, or it has - // no preferred locations (in which we still count the launch as preferred). - private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { - - val locs = task.preferredLocations - - val preferredRacks = new HashSet[String]() - for (preferredHost <- locs) { - val rack = sched.getRackForHost(preferredHost) - if (None != rack) preferredRacks += rack.get - } - - if (preferredRacks.isEmpty) return false - - val hostRack = sched.getRackForHost(hostPort) - - return None != hostRack && preferredRacks.contains(hostRack.get) - } - - // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { - - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - // If explicitly specified, use that - val locality = if (overrideLocality != null) overrideLocality else { - // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... - val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY - } - - findTask(hostPort, locality) match { - case Some(index) => { - // Found a task; do some bookkeeping and return a Mesos task for it - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - val taskLocality = - if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else - if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else - if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else - TaskLocality.ANY - val prefStr = taskLocality.toString - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, execId, hostPort, prefStr)) - // Do various bookkeeping - copiesRunning(index) += 1 - val time = System.currentTimeMillis - val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - if (TaskLocality.NODE_LOCAL == taskLocality) { - lastPreferredLaunchTime = time - } - // Serialize and return the task - val startTime = System.currentTimeMillis - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = System.currentTimeMillis - startTime - increaseRunningTasks(1) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) - } - case _ => - } - } - return None - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - state match { - case TaskState.FINISHED => - taskFinished(tid, state, serializedData) - case TaskState.LOST => - taskLost(tid, state, serializedData) - case TaskState.FAILED => - taskLost(tid, state, serializedData) - case TaskState.KILLED => - taskLost(tid, state, serializedData) - case _ => - } - } - - def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. 
- return - } - val index = info.index - info.markSuccessful() - decreaseRunningTasks(1) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( - tid, info.duration, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - decreaseRunningTasks(1) - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. - if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - return - - case ef: ExceptionFailure => - val key = ef.description - val now = System.currentTimeMillis - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - 
// TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) - decreaseRunningTasks(runningTasks) - sched.taskSetFinished(this) - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable:Schedulable) { - //nothing - } - - override def removeSchedulable(schedulable:Schedulable) { - //nothing - } - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def executorLost(execId: String, hostPort: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // If some task has preferred locations only on hostname, and there are no more executors there, - // put it in the no-prefs list to avoid the wait from delay scheduling - - // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to - // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. - // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if - // there is no host local node for the task (not if there is no process local node for the task) - for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { - // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - taskLost(tid, TaskState.KILLED, null) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. 
- if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = System.currentTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.hostPort, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksFinished < numTasks - } +private[spark] trait TaskSetManager extends Schedulable { + def taskSet: TaskSet + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] + def numPendingTasksForHostPort(hostPort: String): Int + def numRackLocalPendingTasksForHost(hostPort :String): Int + def numPendingTasksForHost(hostPort: String): Int + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) + def error(message: String) } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 37a67f9b1b..664dc9e886 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -2,19 +2,215 @@ package spark.scheduler.local import java.io.File import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import spark._ +import spark.TaskState.TaskState import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.{TaskLocality, TaskInfo} +import spark.scheduler.cluster._ +import akka.actor._ /** * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. 
*/ -private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) + +private[spark] case class LocalReviveOffers() +private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { + def receive = { + case LocalReviveOffers => + logInfo("LocalReviveOffers") + launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => + logInfo("LocalStatusUpdate") + freeCores += 1 + localScheduler.statusUpdate(taskId, state, serializeData) + launchTask(localScheduler.resourceOffer(freeCores)) + } + + def launchTask(tasks : Seq[TaskDescription]) { + for (task <- tasks) + { + freeCores -= 1 + localScheduler.threadPool.submit(new Runnable { + def run() { + localScheduler.runTask(task.taskId,task.serializedTask) + } + }) + } + } +} + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { + var parent: Schedulable = null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = "TaskSet_"+taskSet.stageId.toString + + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val ser = SparkEnv.get.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + + def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def removeSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + //nothing + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + def hasPendingTasks(): Boolean = { + return true + } + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + Thread.currentThread().setContextClassLoader(sched.classLoader) + SparkEnv.set(sched.env) + logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) + if (availableCpus > 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + logInfo(taskSet.tasks(index).toString) + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + logInfo("taskId:%d,task:%s".format(index,task)) + val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + val bytes = 
Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + return Some(new TaskDescription(taskId, null, taskName, bytes)) + case None => {} + } + } + return None + } + + def numPendingTasksForHostPort(hostPort: String): Int = { + return 0 + } + + def numRackLocalPendingTasksForHost(hostPort :String): Int = { + return 0 + } + + def numPendingTasksForHost(hostPort: String): Int = { + return 0 + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > 4) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) + //logError(errorMessage) + //sched.listener.taskEnded(task, reason, null, null, info, null) + sched.listener.taskSetFailed(taskSet, errorMessage) + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + } + } + } + + def error(message: String) { + } +} + +private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with Logging { @@ -30,90 +226,126 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + val activeTaskSets = new HashMap[String, TaskSetManager] + val taskIdToTaskSetId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + var localActor: ActorRef = null // TODO: Need to take into account stage priority in scheduling - override def start() { } + override def start() { + //default scheduler is FIFO + val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO") + //temporarily set rootPool name to empty + rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0) + schedulableBuilder = { + schedulingMode match { + case "FIFO" => + new FIFOSchedulableBuilder(rootPool) + case "FAIR" => + new 
FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + + //val properties = new ArrayBuffer[(String, String)] + localActor = env.actorSystem.actorOf( + Props(new LocalActor(this, threads)), "Test") + } override def setListener(listener: TaskSchedulerListener) { this.listener = listener } override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - val failCount = new Array[Int](tasks.size) + var manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! LocalReviveOffers + } - def submitTask(task: Task[_], idInJob: Int) { - val myAttemptId = attemptId.getAndIncrement() - threadPool.submit(new Runnable { - def run() { - runTask(task, idInJob, myAttemptId) - } - }) + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) } - def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser) - logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserStart = System.currentTimeMillis() - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - val deserTime = System.currentTimeMillis() - deserStart - - // Run it - val result: Any = deserializedTask.run(attemptId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). 
- val serResult = ser.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = ser.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - logInfo("Finished " + task) - info.markSuccessful() - deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - - // If the threadpool has not already been shutdown, notify DAGScheduler - if (!Thread.currentThread().isInterrupted) - listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null)) - } catch { - case t: Throwable => { - logError("Exception in task " + idInJob, t) - failCount.synchronized { - failCount(idInJob) += 1 - if (failCount(idInJob) <= maxFailures) { - submitTask(task, idInJob) - } else { - // TODO: Do something nicer here to return all the way to the user - if (!Thread.currentThread().isInterrupted) { - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) - listener.taskEnded(task, failure, null, null, info, null) - } - } + var launchTask = false + for (manager <- sortedTaskSetQueue) { + do { + launchTask = false + logInfo("freeCores is" + freeCpuCores) + manager.slaveOffer(null,null,freeCpuCores) match { + case Some(task) => + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true + case None => {} } - } - } + } while(launchTask) } + return tasks + } - for ((task, i) <- tasks.zipWithIndex) { - submitTask(task, i) - } + def taskSetFinished(manager: TaskSetManager) { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= manager.taskSet.id + } + + def runTask(taskId: Long, bytes: ByteBuffer) { + logInfo("Running " + taskId) + val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + // Set the Spark execution environment for the worker thread + SparkEnv.set(env) + val ser = SparkEnv.get.closureSerializer.newInstance() + try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; + // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile + val deserStart = System.currentTimeMillis() + val deserializedTask = ser.deserialize[Task[_]]( + taskBytes, Thread.currentThread.getContextClassLoader) + val deserTime = System.currentTimeMillis() - deserStart + + // Run it + val result: Any = deserializedTask.run(taskId) + + // Serialize and deserialize the result to emulate what the Mesos + // executor does. This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). 
+ val serResult = ser.serialize(result) + deserializedTask.metrics.get.resultSize = serResult.limit() + val resultToReturn = ser.deserialize[Any](serResult) + val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( + ser.serialize(Accumulators.values)) + logInfo("Finished " + taskId) + deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough + deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt + + val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) + val serializedResult = ser.serialize(taskResult) + localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) + } catch { + case t: Throwable => { + val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) + localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) + } + } } /** @@ -128,6 +360,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) @@ -143,7 +376,14 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } } - override def stop() { + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) + { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetManager.statusUpdate(taskId, state, serializedData) + } + + override def stop() { threadPool.shutdownNow() } diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index a39418b716..e6ad90192e 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -16,7 +16,7 @@ class DummyTaskSetManager( initNumTasks: Int, clusterScheduler: ClusterScheduler, taskSet: TaskSet) - extends TaskSetManager(clusterScheduler,taskSet) { + extends ClusterTaskSetManager(clusterScheduler,taskSet) { parent = null weight = 1 -- cgit v1.2.3 From c3db3ea55467c3fb053453c8c567db357d939640 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Thu, 30 May 2013 20:49:40 +0800 Subject: 1. Add unit test for local scheduler 2. Move localTaskSetManager to a new file --- .../spark/scheduler/local/LocalScheduler.scala | 241 ++++----------------- .../scheduler/local/LocalTaskSetManager.scala | 173 +++++++++++++++ .../spark/scheduler/LocalSchedulerSuite.scala | 171 +++++++++++++++ 3 files changed, 385 insertions(+), 200 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala create mode 100644 core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 664dc9e886..69dacfc2bd 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -15,7 +15,7 @@ import spark.scheduler.cluster._ import akka.actor._ /** - * A simple TaskScheduler implementation that runs tasks locally in a thread pool. 
Optionally + * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ @@ -26,10 +26,8 @@ private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, seri private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { def receive = { case LocalReviveOffers => - logInfo("LocalReviveOffers") launchTask(localScheduler.resourceOffer(freeCores)) case LocalStatusUpdate(taskId, state, serializeData) => - logInfo("LocalStatusUpdate") freeCores += 1 localScheduler.statusUpdate(taskId, state, serializeData) launchTask(localScheduler.resourceOffer(freeCores)) @@ -48,168 +46,6 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I } } -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { - var parent: Schedulable = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_"+taskSet.stageId.toString - - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val ser = SparkEnv.get.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - - def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - def addSchedulable(schedulable: Schedulable): Unit = { - //nothing - } - - def removeSchedulable(schedulable: Schedulable): Unit = { - //nothing - } - - def getSchedulableByName(name: String): Schedulable = { - return null - } - - def executorLost(executorId: String, host: String): Unit = { - //nothing - } - - def checkSpeculatableTasks(): Boolean = { - return true - } - - def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - def hasPendingTasks(): Boolean = { - return true - } - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { - Thread.currentThread().setContextClassLoader(sched.classLoader) - SparkEnv.set(sched.env) - logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - logInfo(taskSet.tasks(index).toString) - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - logInfo("taskId:%d,task:%s".format(index,task)) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) - 
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - return Some(new TaskDescription(taskId, null, taskName, bytes)) - case None => {} - } - } - return None - } - - def numPendingTasksForHostPort(hostPort: String): Int = { - return 0 - } - - def numRackLocalPendingTasksForHost(hostPort :String): Int = { - return 0 - } - - def numPendingTasksForHost(hostPort: String): Int = { - return 0 - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) - if (numFailures(index) > 4) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) - //logError(errorMessage) - //sched.listener.taskEnded(task, reason, null, null, info, null) - sched.listener.taskSetFailed(taskSet, errorMessage) - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - } - } - } - - def error(message: String) { - } -} - private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with Logging { @@ -233,7 +69,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val taskSetTaskIds = new HashMap[String, HashSet[Long]] var localActor: ActorRef = null - // TODO: Need to take into account stage priority in scheduling override def start() { //default scheduler is FIFO @@ -250,7 +85,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } schedulableBuilder.buildPools() - //val properties = new ArrayBuffer[(String, String)] localActor = env.actorSystem.actorOf( Props(new LocalActor(this, threads)), "Test") } @@ -260,51 +94,56 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } override def submitTasks(taskSet: TaskSet) { - var manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! 
LocalReviveOffers + synchronized { + var manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! LocalReviveOffers + } } def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) - } + synchronized { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + } - var launchTask = false - for (manager <- sortedTaskSetQueue) { + var launchTask = false + for (manager <- sortedTaskSetQueue) { do { launchTask = false - logInfo("freeCores is" + freeCpuCores) manager.slaveOffer(null,null,freeCpuCores) match { case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true case None => {} - } + } } while(launchTask) + } + return tasks } - return tasks } def taskSetFinished(manager: TaskSetManager) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id + synchronized { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= manager.taskSet.id + } } def runTask(taskId: Long, bytes: ByteBuffer) { logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) // Set the Spark execution environment for the worker thread SparkEnv.set(env) val ser = SparkEnv.get.closureSerializer.newInstance() @@ -344,8 +183,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: case t: Throwable => { val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) localActor ! 
LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } } + } } /** @@ -376,11 +215,13 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) - { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetManager.statusUpdate(taskId, state, serializedData) + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { + synchronized { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + taskSetManager.statusUpdate(taskId, state, serializedData) + } } override def stop() { diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala new file mode 100644 index 0000000000..f2e07d162a --- /dev/null +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -0,0 +1,173 @@ +package spark.scheduler.local + +import java.io.File +import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import spark._ +import spark.TaskState.TaskState +import spark.scheduler._ +import spark.scheduler.cluster._ + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { + var parent: Schedulable = null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = "TaskSet_"+taskSet.stageId.toString + + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val ser = SparkEnv.get.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val MAX_TASK_FAILURES = sched.maxFailures + + def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def removeSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + //nothing + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + def hasPendingTasks(): Boolean = { + return true + } + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + SparkEnv.set(sched.env) + logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, 
numTasks)) + if (availableCpus > 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + logInfo(taskSet.tasks(index).toString) + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + return Some(new TaskDescription(taskId, null, taskName, bytes)) + case None => {} + } + } + return None + } + + def numPendingTasksForHostPort(hostPort: String): Int = { + return 0 + } + + def numRackLocalPendingTasksForHost(hostPort :String): Int = { + return 0 + } + + def numPendingTasksForHost(hostPort: String): Int = { + return 0 + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > MAX_TASK_FAILURES) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) + decreaseRunningTasks(runningTasks) + sched.listener.taskSetFailed(taskSet, errorMessage) + // need to delete failed Taskset from schedule queue + sched.taskSetFinished(this) + } + } + } + + def error(message: String) { + } +} diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala new file mode 100644 index 0000000000..37d14ed113 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala @@ -0,0 +1,171 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ConcurrentMap, HashMap} +import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +import java.util.Properties + +class Lock() { + var 
finished = false + def jobWait() = { + synchronized { + while(!finished) { + this.wait() + } + } + } + + def jobFinished() = { + synchronized { + finished = true + this.notifyAll() + } + } +} + +object TaskThreadInfo { + val threadToLock = HashMap[Int, Lock]() + val threadToRunning = HashMap[Int, Boolean]() +} + + +class LocalSchedulerSuite extends FunSuite with LocalSparkContext { + + def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { + + TaskThreadInfo.threadToRunning(threadIndex) = false + val nums = sc.parallelize(threadIndex to threadIndex, 1) + TaskThreadInfo.threadToLock(threadIndex) = new Lock() + new Thread { + if (poolName != null) { + sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToLock(number).jobWait() + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + TaskThreadInfo.threadToRunning(threadIndex) = false + } + }.start() + Thread.sleep(2000) + } + + test("Local FIFO scheduler end-to-end test") { + System.setProperty("spark.cluster.schedulingmode", "FIFO") + sc = new SparkContext("local[4]", "test") + val sem = new Semaphore(0) + + createThread(1,null,sc,sem) + createThread(2,null,sc,sem) + createThread(3,null,sc,sem) + createThread(4,null,sc,sem) + createThread(5,null,sc,sem) + createThread(6,null,sc,sem) + assert(TaskThreadInfo.threadToRunning(1) === true) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === false) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(1).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(3).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === false) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === true) + + TaskThreadInfo.threadToLock(2).jobFinished() + TaskThreadInfo.threadToLock(4).jobFinished() + TaskThreadInfo.threadToLock(5).jobFinished() + TaskThreadInfo.threadToLock(6).jobFinished() + sem.acquire(6) + } + + test("Local fair scheduler end-to-end test") { + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + System.setProperty("spark.cluster.schedulingmode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + + createThread(10,"1",sc,sem) + createThread(20,"2",sc,sem) + createThread(30,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(10) === true) + assert(TaskThreadInfo.threadToRunning(20) === true) + assert(TaskThreadInfo.threadToRunning(30) === true) + + createThread(11,"1",sc,sem) + createThread(21,"2",sc,sem) + createThread(31,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(11) === true) + 
assert(TaskThreadInfo.threadToRunning(21) === true) + assert(TaskThreadInfo.threadToRunning(31) === true) + + createThread(12,"1",sc,sem) + createThread(22,"2",sc,sem) + createThread(32,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(12) === true) + assert(TaskThreadInfo.threadToRunning(22) === true) + assert(TaskThreadInfo.threadToRunning(32) === false) + + TaskThreadInfo.threadToLock(10).jobFinished() + Thread.sleep(1000) + assert(TaskThreadInfo.threadToRunning(32) === true) + + createThread(23,"2",sc,sem) + createThread(33,"3",sc,sem) + + TaskThreadInfo.threadToLock(11).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(23) === true) + assert(TaskThreadInfo.threadToRunning(33) === false) + + TaskThreadInfo.threadToLock(12).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(33) === true) + + TaskThreadInfo.threadToLock(20).jobFinished() + TaskThreadInfo.threadToLock(21).jobFinished() + TaskThreadInfo.threadToLock(22).jobFinished() + TaskThreadInfo.threadToLock(23).jobFinished() + TaskThreadInfo.threadToLock(30).jobFinished() + TaskThreadInfo.threadToLock(31).jobFinished() + TaskThreadInfo.threadToLock(32).jobFinished() + TaskThreadInfo.threadToLock(33).jobFinished() + + sem.acquire(11) + } +} -- cgit v1.2.3 From 926f41cc522def181c167b71dc919a0759c5d3f6 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 30 May 2013 17:55:11 +0800 Subject: fix block manager UI display issue when enable spark.cleaner.ttl --- core/src/main/scala/spark/storage/StorageUtils.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 8f52168c24..81e607868d 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -64,12 +64,12 @@ object StorageUtils { // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rdd = sc.persistentRdds(rddId) - val rddName = Option(rdd.name).getOrElse(rddKey) - val rddStorageLevel = rdd.getStorageLevel - - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize) - }.toArray + sc.persistentRdds.get(rddId).map { r => + val rddName = Option(r.name).getOrElse(rddKey) + val rddStorageLevel = r.getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) + } + }.flatMap(x => x).toArray scala.util.Sorting.quickSort(rddInfos) -- cgit v1.2.3 From 9f84315c055d7a53da8787eb26b336726fc33e8a Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Sat, 1 Jun 2013 00:26:10 +0000 Subject: enhance pipe to support what we can do in hadoop streaming --- core/src/main/scala/spark/RDD.scala | 18 ++++++++++++++++++ core/src/main/scala/spark/rdd/PipedRDD.scala | 25 +++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index dde131696f..5a41db23c2 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -348,17 +348,35 @@ abstract class RDD[T: ClassManifest]( */ def pipe(command: String): RDD[String] = new PipedRDD(this, command) + /** + * Return an RDD created by piping elements to a forked external process. 
+ */ + def pipe(command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = + new PipedRDD(this, command, transform, arguments) + /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = + new PipedRDD(this, command, transform, arguments) + /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: Seq[String], env: Map[String, String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = + new PipedRDD(this, command, env, transform, arguments) + /** * Return a new RDD by applying a function to each partition of this RDD. */ diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 962a1b21ad..969404c95f 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -18,14 +18,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext} class PipedRDD[T: ClassManifest]( prev: RDD[T], command: Seq[String], - envVars: Map[String, String]) + envVars: Map[String, String], + transform: (T, String => Unit) => Any, + arguments: Seq[String] + ) extends RDD[String](prev) { + def this(prev: RDD[T], command: Seq[String], envVars : Map[String, String]) = this(prev, command, envVars, null, null) def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) + def this(prev: RDD[T], command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, command, Map(), transform, arguments) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: RDD[T], command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, PipedRDD.tokenize(command), Map(), transform, arguments) + override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -52,8 +59,22 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + + // input the arguments firstly + if ( arguments != null) { + for (elem <- arguments) { + out.println(elem) + } + // ^A \n as the marker of the end of the arguments + out.println("\u0001") + } for (elem <- firstParent[T].iterator(split, context)) { - out.println(elem) + if (transform != null) { + transform(elem, out.println(_)) + } + else { + out.println(elem) + } } out.close() } -- cgit v1.2.3 From 91aca9224936da84b16ea789cb81914579a0db03 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 31 May 2013 23:21:38 -0700 Subject: Another round of Netty fixes. 1. Avoid race condition between stop and copier completion 2. 
Handle socket exceptions by reporting them and filling in a failed FetchResult --- .../main/java/spark/network/netty/FileClient.java | 24 +++------ .../spark/network/netty/FileClientHandler.java | 8 +++ .../scala/spark/network/netty/ShuffleCopier.scala | 62 ++++++++++++++-------- .../scala/spark/storage/BlockFetcherIterator.scala | 9 ++-- 4 files changed, 58 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 3a62dacbc8..9c9b976ebe 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -8,9 +8,12 @@ import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; class FileClient { + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; @@ -25,25 +28,10 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } - public static final class ChannelCloseListener implements ChannelFutureListener { - private FileClient fc = null; - - public ChannelCloseListener(FileClient fc){ - this.fc = fc; - } - - @Override - public void operationComplete(ChannelFuture future) { - if (fc.bootstrap!=null){ - fc.bootstrap.shutdown(); - fc.bootstrap = null; - } - } - } - public void connect(String host, int port) { try { // Start the connection attempt. 
@@ -58,8 +46,8 @@ class FileClient { public void waitForClose() { try { channel.closeFuture().sync(); - } catch (InterruptedException e){ - e.printStackTrace(); + } catch (InterruptedException e) { + LOG.warn("FileClient interrupted", e); } } diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java index 2069dee5ca..9fc9449827 100644 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { private FileHeader currentHeader = null; + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); @Override public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { @@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { // get file if(in.readableBytes() >= currentHeader.fileLen()) { handle(ctx, in, currentHeader); + handlerCalled = true; currentHeader = null; ctx.close(); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index a91f5a886d..8ec46d42fa 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -9,19 +9,35 @@ import io.netty.util.CharsetUtil import spark.Logging import spark.network.ConnectionManagerId +import scala.collection.JavaConverters._ + private[spark] class ShuffleCopier extends Logging { - def getBlock(cmId: ConnectionManagerId, blockId: String, + def getBlock(host: String, port: Int, blockId: String, resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val fc = new FileClient(handler) - fc.init() - fc.connect(cmId.host, cmId.port) - fc.sendRequest(blockId) - fc.waitForClose() - fc.close() + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, @@ -44,20 +60,18 @@ private[spark] object ShuffleCopier extends Logging { logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + override def handleError(blockId: String) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } } - def runGetBlock(host:String, port:Int, file:String){ - val handler = new ShuffleClientHandler(echoResultCollectCallBack) - val fc = new FileClient(handler) - fc.init(); - fc.connect(host, port) - fc.sendRequest(file) - fc.waitForClose(); - fc.close() + def 
echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } } def main(args: Array[String]) { @@ -71,14 +85,16 @@ private[spark] object ShuffleCopier extends Logging { val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - for (i <- Range(0, threads)) { - val runnable = new Runnable() { + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { def run() { - runGetBlock(host, port, file) + val copier = new ShuffleCopier() + copier.getBlock(host, port, file, echoResultCollectCallBack) } - } - copiers.execute(runnable) - } + }) + }).asJava + copiers.invokeAll(tasks) copiers.shutdown + System.exit(0) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 1d69d658f7..fac416a5b3 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -265,7 +265,7 @@ object BlockFetcherIterator { }).toList } - //keep this to interrupt the threads when necessary + // keep this to interrupt the threads when necessary private def stopCopiers() { for (copier <- copiers) { copier.interrupt() @@ -312,9 +312,10 @@ object BlockFetcherIterator { resultsGotten += 1 val result = results.take() // if all the results has been retrieved, shutdown the copiers - if (resultsGotten == _totalBlocks && copiers != null) { - stopCopiers() - } + // NO need to stop the copiers if we got all the blocks ? + // if (resultsGotten == _totalBlocks && copiers != null) { + // stopCopiers() + // } (result.blockId, if (result.failed) None else Some(result.deserialize())) } } -- cgit v1.2.3 From 038cfc1a9acb32f8c17d883ea64f8cbb324ed82c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 31 May 2013 23:32:18 -0700 Subject: Make connect timeout configurable --- core/src/main/java/spark/network/netty/FileClient.java | 6 ++++-- core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 9c9b976ebe..517772202f 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -17,9 +17,11 @@ class FileClient { private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; + private int connectTimeout = 60*1000; // 1 min - public FileClient(FileClientHandler handler) { + public FileClient(FileClientHandler handler, int connectTimeout) { this.handler = handler; + this.connectTimeout = connectTimeout; } public void init() { @@ -28,7 +30,7 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index 8ec46d42fa..afb2cdbb3a 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ 
b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -18,7 +18,8 @@ private[spark] class ShuffleCopier extends Logging { resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val fc = new FileClient(handler) + val fc = new FileClient(handler, + System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt) try { fc.init() fc.connect(host, port) -- cgit v1.2.3 From 3be7bdcefda13d67633f9b9f6d901722fd5649de Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Sat, 1 Jun 2013 19:32:17 +0530 Subject: Adding example to make Spark RDD from Cassandra --- .../main/scala/spark/examples/CassandraTest.scala | 154 +++++++++++++++++++++ project/SparkBuild.scala | 4 +- 2 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/spark/examples/CassandraTest.scala diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala new file mode 100644 index 0000000000..790b24e6f3 --- /dev/null +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -0,0 +1,154 @@ +package spark.examples + +import org.apache.hadoop.mapreduce.Job +import org.apache.cassandra.hadoop.{ConfigHelper, ColumnFamilyInputFormat} +import org.apache.cassandra.thrift.{IndexExpression, SliceRange, SlicePredicate} +import spark.{RDD, SparkContext} +import SparkContext._ +import java.nio.ByteBuffer +import java.util.SortedMap +import org.apache.cassandra.db.IColumn +import org.apache.cassandra.utils.ByteBufferUtil +import scala.collection.JavaConversions._ + + +/* + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra support for Hadoop. + * + * To run this example, run this file with the following command params - + * + * + * So if you want to run this on localhost this will be, + * local[3] localhost 9160 + * + * The example makes some assumptions: + * 1. You have already created a keyspace called casDemo and it has a column family named Words + * 2. There are column family has a column named "para" which has test content. + * + * You can create the content by running the following script at the bottom of this file with cassandra-cli. 
+ * + */ +object CassandraTest { + def main(args: Array[String]) { + + //Get a SparkContext + val sc = new SparkContext(args(0), "casDemo") + + //Build the job configuration with ConfigHelper provided by Cassandra + val job = new Job() + job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) + + ConfigHelper.setInputInitialAddress(job.getConfiguration(), args(1)) + + ConfigHelper.setInputRpcPort(job.getConfiguration(), args(2)) + + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + + val predicate = new SlicePredicate() + val sliceRange = new SliceRange() + sliceRange.setStart(Array.empty[Byte]) + sliceRange.setFinish(Array.empty[Byte]) + predicate.setSlice_range(sliceRange) + ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) + + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + + //Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), + classOf[ColumnFamilyInputFormat], + classOf[ByteBuffer], + classOf[SortedMap[ByteBuffer, IColumn]]) + + // Let us first get all the paragraphs from the retrieved rows + val paraRdd = casRdd flatMap { + case (key, value) => { + value.filter(v => ByteBufferUtil.string(v._1).compareTo("para") == 0).map(v => ByteBufferUtil.string(v._2.value())) + } + } + + //Lets get the word count in paras + val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) + + counts.collect() foreach { + case(word, count) => println(word + ":" + count) + } + } +} + +/* +create keyspace casDemo; +use casDemo; + +create column family Words with comparator = UTF8Type; +update column family Words with column_metadata = [{column_name: book, validation_class: UTF8Type}, {column_name: para, validation_class: UTF8Type}]; + +assume Words keys as utf8; + +set Words['3musk001']['book'] = 'The Three Musketeers'; +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market town of + Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to + be in as perfect a state of revolution as if the Huguenots had just made + a second La Rochelle of it. Many citizens, seeing the women flying + toward the High Street, leaving their children crying at the open doors, + hastened to don the cuirass, and supporting their somewhat uncertain + courage with a musket or a partisan, directed their steps toward the + hostelry of the Jolly Miller, before which was gathered, increasing + every minute, a compact group, vociferous and full of curiosity.'; + +set Words['3musk002']['book'] = 'The Three Musketeers'; +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without some city + or other registering in its archives an event of this kind. There were + nobles, who made war against each other; there was the king, who made + war against the cardinal; there was Spain, which made war against the + king. Then, in addition to these concealed or public, secret or open + wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, + who made war upon everybody. The citizens always took up arms readily + against thieves, wolves or scoundrels, often against nobles or + Huguenots, sometimes against the king, but never against cardinal or + Spain. 
It resulted, then, from this habit that on the said first Monday + of April, 1625, the citizens, on hearing the clamor, and seeing neither + the red-and-yellow standard nor the livery of the Duc de Richelieu, + rushed toward the hostel of the Jolly Miller. When arrived there, the + cause of the hubbub was apparent to all'; + +set Words['3musk003']['book'] = 'The Three Musketeers'; +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however large + the sum may be; but you ought also to endeavor to perfect yourself in + the exercises becoming a gentleman. I will write a letter today to the + Director of the Royal Academy, and tomorrow he will admit you without + any expense to yourself. Do not refuse this little service. Our + best-born and richest gentlemen sometimes solicit it without being able + to obtain it. You will learn horsemanship, swordsmanship in all its + branches, and dancing. You will make some desirable acquaintances; and + from time to time you can call upon me, just to tell me how you are + getting on, and to say whether I can be of further service to you.'; + + +set Words['thelostworld001']['book'] = 'The Lost World'; +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined against the + red curtain. How beautiful she was! And yet how aloof! We had been + friends, quite good friends; but never could I get beyond the same + comradeship which I might have established with one of my + fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, + and perfectly unsexual. My instincts are all against a woman being too + frank and at her ease with me. It is no compliment to a man. Where + the real sex feeling begins, timidity and distrust are its companions, + heritage from old wicked days when love and violence went often hand in + hand. The bent head, the averted eye, the faltering voice, the wincing + figure--these, and not the unshrinking gaze and frank reply, are the + true signals of passion. Even in my short life I had learned as much + as that--or had inherited it in that race memory which we call instinct.'; + +set Words['thelostworld002']['book'] = 'The Lost World'; +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, red-headed news + editor, and I rather hoped that he liked me. Of course, Beaumont was + the real boss; but he lived in the rarefied atmosphere of some Olympian + height from which he could distinguish nothing smaller than an + international crisis or a split in the Cabinet. Sometimes we saw him + passing in lonely majesty to his inner sanctum, with his eyes staring + vaguely and his mind hovering over the Balkans or the Persian Gulf. He + was above and beyond us. But McArdle was his first lieutenant, and it + was he that we knew. 
The old man nodded as I entered the room, and he + pushed his spectacles far up on his bald forehead.'; + +*/ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0ea23b446f..5152b7b79b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -201,8 +201,8 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") - ) + libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11", + "org.apache.cassandra" % "cassandra-all" % "1.2.5" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru:1.3") exclude("com.ning","compress-lzf") exclude("io.netty","netty") exclude("jline","jline") exclude("log4j","log4j") exclude("org.apache.cassandra.deps", "avro"))) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From 81c2adc15c9e232846d4ad0adf14d007039409fa Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Sun, 2 Jun 2013 12:51:15 +0530 Subject: Removing infix call --- examples/src/main/scala/spark/examples/CassandraTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 790b24e6f3..49b940d8a7 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -60,7 +60,7 @@ object CassandraTest { classOf[SortedMap[ByteBuffer, IColumn]]) // Let us first get all the paragraphs from the retrieved rows - val paraRdd = casRdd flatMap { + val paraRdd = casRdd.flatMap { case (key, value) => { value.filter(v => ByteBufferUtil.string(v._1).compareTo("para") == 0).map(v => ByteBufferUtil.string(v._2.value())) } @@ -69,8 +69,8 @@ object CassandraTest { //Lets get the word count in paras val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) - counts.collect() foreach { - case(word, count) => println(word + ":" + count) + counts.collect().foreach { + case (word, count) => println(word + ":" + count) } } } -- cgit v1.2.3 From 6d8423fd1b490d541f0ea379068b8954002d624f Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Sun, 2 Jun 2013 13:03:45 +0530 Subject: Adding deps to examples/pom.xml Fixing exclusion in examples deps in SparkBuild.scala --- examples/pom.xml | 35 +++++++++++++++++++++++++++++++++++ project/SparkBuild.scala | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/examples/pom.xml b/examples/pom.xml index c42d2bcdb9..b4c5251d68 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,6 +34,41 @@ scalacheck_${scala.version} test + + org.apache.cassandra + cassandra-all + 1.2.5 + + + com.google.guava + guava + + + com.googlecode.concurrentlinkedhashmap + concurrentlinkedhashmap-lru + + + com.ning + compress-lzf + + + io.netty + netty + + + jline + jline + + + log4j + log4j + + + org.apache.cassandra.deps + avro + + + target/scala-${scala.version}/classes diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5152b7b79b..7f3e223c2e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -202,7 +202,7 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11", - "org.apache.cassandra" % "cassandra-all" % "1.2.5" 
exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru:1.3") exclude("com.ning","compress-lzf") exclude("io.netty","netty") exclude("jline","jline") exclude("log4j","log4j") exclude("org.apache.cassandra.deps", "avro"))) + "org.apache.cassandra" % "cassandra-all" % "1.2.5" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") exclude("com.ning","compress-lzf") exclude("io.netty","netty") exclude("jline","jline") exclude("log4j","log4j") exclude("org.apache.cassandra.deps", "avro"))) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From 4a9913d66a61ac9ef9cab0e08f6151dc2624fd11 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Sun, 2 Jun 2013 23:21:09 +0000 Subject: add ut for pipe enhancement --- core/src/test/scala/spark/PipedRDDSuite.scala | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index a6344edf8f..ee55952a94 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -19,6 +19,37 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { assert(c(3) === "4") } + test("advanced pipe") { + sc = new SparkContext("local", "test") + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + val piped = nums.pipe(Seq("cat"), (i:Int, f: String=> Unit) => f(i + "_"), Array("0")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str=>str.split("\t")(0)).pipe(Seq("cat"), (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}, Array("0")).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } + test("pipe with env variable") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) -- cgit v1.2.3 From 606bb1b450064a2b909e4275ce45325dbbef4eca Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Fri, 31 May 2013 15:40:41 +0800 Subject: Fix schedulingAlgorithm bugs for unit test --- .../spark/scheduler/cluster/SchedulingAlgorithm.scala | 17 +++++++++++++---- .../scala/spark/scheduler/ClusterSchedulerSuite.scala | 9 ++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index a5d6285c99..13120edf63 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -40,15 +40,24 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true + var compare:Int = 0 if (s1Needy && !s2Needy) { - res = true + return true } else if (!s1Needy && s2Needy) { - res = false + return false } else if (s1Needy && s2Needy) { - res = 
minShareRatio1 <= minShareRatio2 + compare = minShareRatio1.compareTo(minShareRatio2) + } else { + compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) + } + + if (compare < 0) { + res = true + } else if (compare > 0) { + res = false } else { - res = taskToWeightRatio1 <= taskToWeightRatio2 + return s1.name < s2.name } return res } diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index a39418b716..c861597c6b 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -88,7 +88,7 @@ class DummyTask(stageId: Int) extends Task[Int](stageId) } } -class ClusterSchedulerSuite extends FunSuite with LocalSparkContext { +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = { new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet) @@ -96,8 +96,11 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext { def resourceOffer(rootPool: Pool): Int = { val taskSetQueue = rootPool.getSortedTaskSetQueue() - for (taskSet <- taskSetQueue) - { + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { taskSet.slaveOffer("execId_1", "hostname_1", 1) match { case Some(task) => return taskSet.stageId -- cgit v1.2.3 From 56c64c403383e90a5fd33b6a1f72527377d9bee0 Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Mon, 3 Jun 2013 12:48:35 +0530 Subject: A better way to read column value if you are sure the column exists in every row. 
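
The idea, sketched below with plain Scala collections as a stand-in for the Cassandra row type (illustrative names only, not the Hadoop/Cassandra API), is that a column guaranteed to be present can be fetched by key directly instead of filtering every column of the row:

    object ColumnLookupSketch {
      def main(args: Array[String]): Unit = {
        // One "row": column name -> column value, standing in for SortedMap[ByteBuffer, IColumn].
        val row: Map[String, String] = Map("book" -> "Moby Dick", "para" -> "Call me Ishmael.")

        // Filter-based access: scans every column and yields a collection, even for a single hit.
        val viaFilter: Iterable[String] = row.filter { case (name, _) => name == "para" }.values

        // Direct access: valid when the column is known to exist in every row.
        val viaGet: String = row("para")

        println(viaFilter.mkString)  // Call me Ishmael.
        println(viaGet)              // Call me Ishmael.
      }
    }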
--- examples/src/main/scala/spark/examples/CassandraTest.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 49b940d8a7..6b9fd502e2 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -10,6 +10,8 @@ import java.util.SortedMap import org.apache.cassandra.db.IColumn import org.apache.cassandra.utils.ByteBufferUtil import scala.collection.JavaConversions._ +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /* @@ -60,9 +62,9 @@ object CassandraTest { classOf[SortedMap[ByteBuffer, IColumn]]) // Let us first get all the paragraphs from the retrieved rows - val paraRdd = casRdd.flatMap { + val paraRdd = casRdd.map { case (key, value) => { - value.filter(v => ByteBufferUtil.string(v._1).compareTo("para") == 0).map(v => ByteBufferUtil.string(v._2.value())) + ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) } } -- cgit v1.2.3 From b104c7f5c7e2b173fe1b10035efbc00e43df13ec Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Mon, 3 Jun 2013 15:15:52 +0530 Subject: Example to write the output to cassandra --- .../main/scala/spark/examples/CassandraTest.scala | 48 +++++++++++++++++++--- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 6b9fd502e2..2cc62b9fe9 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -1,17 +1,16 @@ package spark.examples import org.apache.hadoop.mapreduce.Job -import org.apache.cassandra.hadoop.{ConfigHelper, ColumnFamilyInputFormat} -import org.apache.cassandra.thrift.{IndexExpression, SliceRange, SlicePredicate} +import org.apache.cassandra.hadoop.{ColumnFamilyOutputFormat, ConfigHelper, ColumnFamilyInputFormat} +import org.apache.cassandra.thrift._ import spark.{RDD, SparkContext} -import SparkContext._ +import spark.SparkContext._ import java.nio.ByteBuffer import java.util.SortedMap import org.apache.cassandra.db.IColumn import org.apache.cassandra.utils.ByteBufferUtil import scala.collection.JavaConversions._ -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + /* @@ -44,8 +43,15 @@ object CassandraTest { ConfigHelper.setInputRpcPort(job.getConfiguration(), args(2)) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), args(1)) + + ConfigHelper.setOutputRpcPort(job.getConfiguration(), args(2)) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") + + val predicate = new SlicePredicate() val sliceRange = new SliceRange() sliceRange.setStart(Array.empty[Byte]) @@ -55,6 +61,8 @@ object CassandraTest { ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + //Make a new Hadoop RDD val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[ColumnFamilyInputFormat], @@ -74,6 +82,33 @@ object CassandraTest { counts.collect().foreach { case (word, count) => println(word + ":" + count) } + + counts.map { + case (word, count) => { + val colWord = new org.apache.cassandra.thrift.Column() + colWord.setName(ByteBufferUtil.bytes("word")) + 
colWord.setValue(ByteBufferUtil.bytes(word)) + colWord.setTimestamp(System.currentTimeMillis) + + val colCount = new org.apache.cassandra.thrift.Column() + colCount.setName(ByteBufferUtil.bytes("wcount")) + colCount.setValue(ByteBufferUtil.bytes(count.toLong)) + colCount.setTimestamp(System.currentTimeMillis) + + + val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) + + val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(0).column_or_supercolumn.setColumn(colWord) + + mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(1).column_or_supercolumn.setColumn(colCount) + (outputkey, mutations) + } + }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], + classOf[ColumnFamilyOutputFormat], job.getConfiguration) + } } @@ -81,6 +116,9 @@ object CassandraTest { create keyspace casDemo; use casDemo; +create column family WordCount with comparator = UTF8Type; +update column family WordCount with column_metadata = [{column_name: word, validation_class: UTF8Type}, {column_name: wcount, validation_class: LongType}]; + create column family Words with comparator = UTF8Type; update column family Words with column_metadata = [{column_name: book, validation_class: UTF8Type}, {column_name: para, validation_class: UTF8Type}]; -- cgit v1.2.3 From a058b0acf3e5ae41e64640feeace3d4e32f47401 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:10:00 -0700 Subject: Delete a file for a block if it already exists. --- core/src/main/scala/spark/storage/DiskStore.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c7281200e7..2be5d01e31 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -195,9 +195,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) + var file = getFile(blockId) if (!allowAppendExisting && file.exists()) { - throw new Exception("File for block " + blockId + " already exists on disk: " + file) + // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task + // was rescheduled on the same machine as the old task ? + logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") + file.delete() + // Reopen the file + file = getFile(blockId) + // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } -- cgit v1.2.3 From cd347f547a9a9b7bdd0d3f4734ae5c13be54f75d Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:27:51 -0700 Subject: Reuse the file object as it is valid after delete --- core/src/main/scala/spark/storage/DiskStore.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 2be5d01e31..e51d258a21 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -201,8 +201,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // was rescheduled on the same machine as the old task ? 
logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") file.delete() - // Reopen the file - file = getFile(blockId) // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file -- cgit v1.2.3 From 96943a1cc054d7cf80eb8d3dfc7fb19ce48d3c0a Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:29:38 -0700 Subject: var to val --- core/src/main/scala/spark/storage/DiskStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index e51d258a21..cd85fa1e9d 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -195,7 +195,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - var file = getFile(blockId) + val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task // was rescheduled on the same machine as the old task ? -- cgit v1.2.3 From d1286231e0db15e480bd7d6a600b419db3391b27 Mon Sep 17 00:00:00 2001 From: Konstantin Boudnik Date: Wed, 29 May 2013 20:14:59 -0700 Subject: Sometime Maven build runs out of PermGen space. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 6ee64d07c2..ce77ba37c6 100644 --- a/pom.xml +++ b/pom.xml @@ -59,6 +59,9 @@ 1.6.1 4.1.2 1.2.17 + + 0m + 512m @@ -392,6 +395,10 @@ -Xms64m -Xmx1024m + -XX:PermSize + ${PermGen} + -XX:MaxPermSize + ${MaxPermGen} -source -- cgit v1.2.3 From 9d359043574f6801ba15ec9d016eba0f00ac2349 Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Tue, 4 Jun 2013 22:12:47 -0700 Subject: In the current code, when both partitions happen to have zero-length, the return mean will be NaN. Consequently, the result of mean after reducing over all partitions will also be NaN, which is not correct if there are partitions with non-zero length. This patch fixes this issue. 
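
A minimal standalone sketch of the guard (illustrative names only; this is not the Spark StatCounter class itself, and only the running mean is tracked here): merging an empty partition's statistics must copy or skip state rather than evaluate 0/0.

    object MergeMeanSketch {
      final class Stats(var n: Long = 0L, var mu: Double = 0.0) {
        def add(x: Double): Stats = { n += 1; mu += (x - mu) / n; this }
        def merge(other: Stats): Stats = {
          if (n == 0) {                // this side is empty: take the other side's state as-is
            n = other.n
            mu = other.mu
          } else if (other.n != 0) {   // combine only when both sides actually saw data
            mu = (mu * n + other.mu * other.n) / (n + other.n)
            n += other.n
          }                            // other side empty: nothing to merge
          this
        }
      }

      def main(args: Array[String]): Unit = {
        // Two empty partitions plus one non-empty partition, reduced as a job would reduce them.
        val parts = Seq(new Stats(), new Stats(), new Stats().add(1.0).add(3.0))
        val total = parts.reduce(_ merge _)
        // With the guard the mean is 2.0; the unguarded weighted-average formula hits 0.0/0.0 = NaN
        // when both sides are empty, and that NaN then poisons every later merge.
        println(total.mu)
      }
    }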
--- core/src/main/scala/spark/util/StatCounter.scala | 26 +++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala index 5f80180339..2b980340b7 100644 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ b/core/src/main/scala/spark/util/StatCounter.scala @@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (other == this) { merge(other.copy()) // Avoid overwriting fields in a weird order } else { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n } - m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) - n += other.n - this + this } } -- cgit v1.2.3 From c851957fe4798d5dfb8deba7bf79a035a0543c74 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 5 Jun 2013 14:28:38 -0700 Subject: Don't write zero block files with java serializer --- .../scala/spark/storage/BlockFetcherIterator.scala | 5 ++- core/src/main/scala/spark/storage/DiskStore.scala | 46 ++++++++++++++-------- .../scala/spark/storage/ShuffleBlockManager.scala | 2 +- core/src/test/scala/spark/ShuffleSuite.scala | 26 ++++++++++++ 4 files changed, 61 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index fac416a5b3..843069239c 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -71,6 +71,7 @@ object BlockFetcherIterator { logDebug("Getting " + _totalBlocks + " blocks") protected var startTime = System.currentTimeMillis protected val localBlockIds = new ArrayBuffer[String]() + protected val localNonZeroBlocks = new ArrayBuffer[String]() protected val remoteBlockIds = new HashSet[String]() // A queue to hold our results. @@ -129,6 +130,8 @@ object BlockFetcherIterator { for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { localBlockIds ++= blockInfos.map(_._1) + localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) } else { remoteBlockIds ++= blockInfos.map(_._1) // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them @@ -172,7 +175,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. 
Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localBlockIds) { + for (id <- localNonZeroBlocks) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index cd85fa1e9d..c1cff25552 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private var bs: OutputStream = null private var objOut: SerializationStream = null private var lastValidPosition = 0L + private var initialized = false override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) objOut = serializer.newInstance().serializeStream(bs) + initialized = true this } override def close() { - objOut.close() - bs.close() - channel = null - bs = null - objOut = null + if (initialized) { + objOut.close() + bs.close() + channel = null + bs = null + objOut = null + } // Invoke the close callback handler. super.close() } @@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } } override def revertPartialWrites() { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. 
+ objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } } override def write(value: Any) { + if (!initialized) { + open() + } objOut.writeObject(value) } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 49eabfb0d2..44638e0c2d 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) } new ShuffleWriterGroup(mapId, writers) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index b967016cf7..33b02fff80 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + } object ShuffleSuite { -- cgit v1.2.3 From cb2f5046ee99582a5038a78478c23468b14c134e Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 5 Jun 2013 15:09:02 -0700 Subject: Pass in bufferSize to BufferedOutputStream --- core/src/main/scala/spark/storage/DiskStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c1cff25552..0af6e4a359 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -40,7 +40,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true this -- cgit v1.2.3 From e179ff8a32fc08cc308dc99bac2527d350d0d970 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Wed, 5 Jun 2013 22:41:05 +0000 Subject: update according to comments --- core/src/main/scala/spark/RDD.scala | 89 
+++++++++++++++++++++++---- core/src/main/scala/spark/rdd/PipedRDD.scala | 33 +++++----- core/src/test/scala/spark/PipedRDDSuite.scala | 7 ++- 3 files changed, 99 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 5a41db23c2..a1c9604324 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -16,6 +16,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import spark.broadcast.Broadcast import spark.Partitioner._ import spark.partial.BoundedDouble import spark.partial.CountEvaluator @@ -351,31 +352,93 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = - new PipedRDD(this, command, transform, arguments) + def pipe(command: String, env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: String, + env: Map[String, String], + transform: (T,String => Unit) => Any, + pipeContext: Broadcast[U], + delimiter: String): RDD[String] = + new PipedRDD(this, command, env, transform, pipeContext, delimiter) /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = - new PipedRDD(this, command, transform, arguments) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. 
+ */ + def pipe[U<: Seq[String]]( + command: String, + transform: (T,String => Unit) => Any, + pipeContext: Broadcast[U]): RDD[String] = + new PipedRDD(this, command, Map[String, String](), transform, pipeContext, "\u0001") /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = - new PipedRDD(this, command, env) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: String, + env: Map[String, String], + transform: (T,String => Unit) => Any, + pipeContext: Broadcast[U]): RDD[String] = + new PipedRDD(this, command, env, transform, pipeContext, "\u0001") /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], env: Map[String, String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = - new PipedRDD(this, command, env, transform, arguments) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: Seq[String], + env: Map[String, String] = Map(), + transform: (T,String => Unit) => Any = null, + pipeContext: Broadcast[U] = null, + delimiter: String = "\u0001"): RDD[String] = + new PipedRDD(this, command, env, transform, pipeContext, delimiter) /** * Return a new RDD by applying a function to each partition of this RDD. 
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 969404c95f..d58aaae709 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -9,29 +9,33 @@ import scala.collection.mutable.ArrayBuffer import scala.io.Source import spark.{RDD, SparkEnv, Partition, TaskContext} +import spark.broadcast.Broadcast /** * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest]( +class PipedRDD[T: ClassManifest, U <: Seq[String]]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], transform: (T, String => Unit) => Any, - arguments: Seq[String] + pipeContext: Broadcast[U], + delimiter: String ) extends RDD[String](prev) { - def this(prev: RDD[T], command: Seq[String], envVars : Map[String, String]) = this(prev, command, envVars, null, null) - def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) - def this(prev: RDD[T], command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, command, Map(), transform, arguments) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - def this(prev: RDD[T], command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, PipedRDD.tokenize(command), Map(), transform, arguments) + def this( + prev: RDD[T], + command: String, + envVars: Map[String, String] = Map(), + transform: (T, String => Unit) => Any = null, + pipeContext: Broadcast[U] = null, + delimiter: String = "\u0001") = + this(prev, PipedRDD.tokenize(command), envVars, transform, pipeContext, delimiter) override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -60,19 +64,18 @@ class PipedRDD[T: ClassManifest]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - // input the arguments firstly - if ( arguments != null) { - for (elem <- arguments) { + // input the pipeContext firstly + if ( pipeContext != null) { + for (elem <- pipeContext.value) { out.println(elem) } - // ^A \n as the marker of the end of the arguments - out.println("\u0001") + // delimiter\n as the marker of the end of the pipeContext + out.println(delimiter) } for (elem <- firstParent[T].iterator(split, context)) { if (transform != null) { transform(elem, out.println(_)) - } - else { + } else { out.println(elem) } } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index ee55952a94..d2852867de 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -23,7 +23,8 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat"), (i:Int, f: String=> Unit) => f(i + "_"), Array("0")) + val piped = nums.pipe(Seq("cat"), Map[String, String](), + (i:Int, f: String=> Unit) => f(i + "_"), sc.broadcast(List("0"))) val c = piped.collect() @@ -38,7 +39,9 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { assert(c(7) === "4_") val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str=>str.split("\t")(0)).pipe(Seq("cat"), 
(i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}, Array("0")).collect() + val d = nums1.groupBy(str=>str.split("\t")(0)). + pipe(Seq("cat"), Map[String, String](), (i:Tuple2[String, Seq[String]], f: String=> Unit) => + {for (e <- i._2){ f(e + "_")}}, sc.broadcast(List("0"))).collect() assert(d.size === 8) assert(d(0) === "0") assert(d(1) === "\u0001") -- cgit v1.2.3 From ac480fd977e0de97bcfe646e39feadbd239c1c29 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 6 Jun 2013 16:34:27 -0700 Subject: Clean up variables and counters in BlockFetcherIterator --- .../scala/spark/storage/BlockFetcherIterator.scala | 54 +++++++++++++--------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 843069239c..bb78207c9f 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -67,12 +67,20 @@ object BlockFetcherIterator { throw new IllegalArgumentException("BlocksByAddress is null") } - protected var _totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + _totalBlocks + " blocks") + // Total number blocks fetched (local + remote). Also number of FetchResults expected + protected var _numBlocksToFetch = 0 + protected var startTime = System.currentTimeMillis - protected val localBlockIds = new ArrayBuffer[String]() - protected val localNonZeroBlocks = new ArrayBuffer[String]() - protected val remoteBlockIds = new HashSet[String]() + + // This represents the number of local blocks, also counting zero-sized blocks + private var numLocal = 0 + // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks + protected val localBlocksToFetch = new ArrayBuffer[String]() + + // This represents the number of remote blocks, also counting zero-sized blocks + private var numRemote = 0 + // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks + protected val remoteBlocksToFetch = new HashSet[String]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -125,15 +133,15 @@ object BlockFetcherIterator { protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = _totalBlocks val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) - _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) + numLocal = blockInfos.size + // Filter out zero-sized blocks + localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) + _numBlocksToFetch += localBlocksToFetch.size } else { - remoteBlockIds ++= blockInfos.map(_._1) + numRemote += blockInfos.size // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. 
@@ -147,10 +155,10 @@ object BlockFetcherIterator { // Skip empty blocks if (size > 0) { curBlocks += ((blockId, size)) + remoteBlocksToFetch += blockId + _numBlocksToFetch += 1 curRequestSize += size - } else if (size == 0) { - _totalBlocks -= 1 - } else { + } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= minRequestSize) { @@ -166,8 +174,8 @@ object BlockFetcherIterator { } } } - logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + - originalTotalBlocks + " blocks") + logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + + totalBlocks + " blocks") remoteRequests } @@ -175,7 +183,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localNonZeroBlocks) { + for (id <- localBlocksToFetch) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight @@ -201,7 +209,7 @@ object BlockFetcherIterator { sendRequest(fetchRequests.dequeue()) } - val numGets = remoteBlockIds.size - fetchRequests.size + val numGets = remoteRequests.size - fetchRequests.size logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) // Get Local Blocks @@ -213,7 +221,7 @@ object BlockFetcherIterator { //an iterator that will read fetched blocks off the queue as they arrive. @volatile protected var resultsGotten = 0 - override def hasNext: Boolean = resultsGotten < _totalBlocks + override def hasNext: Boolean = resultsGotten < _numBlocksToFetch override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 @@ -230,9 +238,9 @@ object BlockFetcherIterator { } // Implementing BlockFetchTracker trait. - override def totalBlocks: Int = _totalBlocks - override def numLocalBlocks: Int = localBlockIds.size - override def numRemoteBlocks: Int = remoteBlockIds.size + override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote override def remoteFetchTime: Long = _remoteFetchTime override def fetchWaitTime: Long = _fetchWaitTime override def remoteBytesRead: Long = _remoteBytesRead @@ -294,7 +302,7 @@ object BlockFetcherIterator { private var copiers: List[_ <: Thread] = null override def initialize() { - // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + // Split Local Remote Blocks and set numBlocksToFetch val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order for (request <- Utils.randomize(remoteRequests)) { @@ -316,7 +324,7 @@ object BlockFetcherIterator { val result = results.take() // if all the results has been retrieved, shutdown the copiers // NO need to stop the copiers if we got all the blocks ? 
- // if (resultsGotten == _totalBlocks && copiers != null) { + // if (resultsGotten == _numBlocksToFetch && copiers != null) { // stopCopiers() // } (result.blockId, if (result.failed) None else Some(result.deserialize())) -- cgit v1.2.3 From c9ca0a4a588b4c7dc553b155336ae5b95aa9ddd4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 7 Jun 2013 22:40:44 -0700 Subject: Small code style fix to SchedulingAlgorithm.scala --- .../src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index 13120edf63..e071917c00 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -53,13 +53,12 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { } if (compare < 0) { - res = true + return true } else if (compare > 0) { - res = false + return false } else { return s1.name < s2.name } - return res } } -- cgit v1.2.3 From b58a29295b2e610cadf1cac44438337ce9b51537 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 7 Jun 2013 22:51:28 -0700 Subject: Small formatting and style fixes --- .../spark/scheduler/cluster/SchedulingAlgorithm.scala | 8 ++++---- core/src/main/scala/spark/storage/StorageUtils.scala | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index e071917c00..f33310a34a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -13,11 +13,11 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { val priority1 = s1.priority val priority2 = s2.priority - var res = Math.signum(priority1 - priority2) + var res = math.signum(priority1 - priority2) if (res == 0) { val stageId1 = s1.stageId val stageId2 = s2.stageId - res = Math.signum(stageId1 - stageId2) + res = math.signum(stageId1 - stageId2) } if (res < 0) { return true @@ -35,8 +35,8 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val runningTasks2 = s2.runningTasks val s1Needy = runningTasks1 < minShare1 val s2Needy = runningTasks2 < minShare2 - val minShareRatio1 = runningTasks1.toDouble / Math.max(minShare1, 1.0).toDouble - val minShareRatio2 = runningTasks2.toDouble / Math.max(minShare2, 1.0).toDouble + val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble + val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 81e607868d..950c0cdf35 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -55,21 +55,21 @@ object StorageUtils { }.mapValues(_.values.toArray) // For each RDD, generate an RDDInfo object - val rddInfos = groupedRddBlocks.map { case(rddKey, rddBlocks) => - + val rddInfos = 
groupedRddBlocks.map { case (rddKey, rddBlocks) => // Add up memory and disk sizes val memSize = rddBlocks.map(_.memSize).reduce(_ + _) val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt - // Get the friendly name for the rdd, if available. + + // Get the friendly name and storage level for the RDD, if available sc.persistentRdds.get(rddId).map { r => - val rddName = Option(r.name).getOrElse(rddKey) - val rddStorageLevel = r.getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) + val rddName = Option(r.name).getOrElse(rddKey) + val rddStorageLevel = r.getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) } - }.flatMap(x => x).toArray + }.flatten.toArray scala.util.Sorting.quickSort(rddInfos) -- cgit v1.2.3 From 1a4d93c025e5d3679257a622f49dfaade4ac18c2 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Sat, 8 Jun 2013 14:23:39 +0800 Subject: modify to pass job annotation by localProperties and use daeamon thread to do joblogger's work --- .../scala/spark/BlockStoreShuffleFetcher.scala | 1 + core/src/main/scala/spark/RDD.scala | 10 +- core/src/main/scala/spark/SparkContext.scala | 8 +- core/src/main/scala/spark/Utils.scala | 10 +- core/src/main/scala/spark/executor/Executor.scala | 1 + .../main/scala/spark/executor/TaskMetrics.scala | 12 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 8 + .../src/main/scala/spark/scheduler/JobLogger.scala | 317 +++++++++++++++++++++ .../main/scala/spark/scheduler/SparkListener.scala | 33 ++- .../scala/spark/scheduler/JobLoggerSuite.scala | 105 +++++++ .../scala/spark/scheduler/SparkListenerSuite.scala | 2 +- 11 files changed, 495 insertions(+), 12 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/JobLogger.scala create mode 100644 core/src/test/scala/spark/scheduler/JobLoggerSuite.scala diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index e1fb02157a..3239f4c385 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics + shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..8c0b7ca417 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -114,6 +114,14 @@ abstract class RDD[T: ClassManifest]( this } + /**User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo._4 + + /**reset generator*/ + def setGenerator(_generator: String) = { + generator = _generator + } + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. This can only be used to assign a new storage level if the RDD does not @@ -788,7 +796,7 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. 
*/ - private[spark] val origin = Utils.getSparkCallSite + private[spark] val origin = Utils.formatSparkCallSite private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bc05d08fd6..b67a2066c8 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -48,7 +48,7 @@ import spark.scheduler.local.LocalScheduler import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, TimeStampedHashMap} - +import spark.scheduler.JobLogger /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -510,7 +510,7 @@ class SparkContext( def addSparkListener(listener: SparkListener) { dagScheduler.sparkListeners += listener } - + addSparkListener(new JobLogger) /** * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. @@ -630,7 +630,7 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) @@ -713,7 +713,7 @@ class SparkContext( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ec15326014..1630b2b4b0 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -528,7 +528,7 @@ private object Utils extends Logging { * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. */ - def getSparkCallSite: String = { + def getCallSiteInfo = { val trace = Thread.currentThread.getStackTrace().filter( el => (!el.getMethodName.contains("getStackTrace"))) @@ -540,6 +540,7 @@ private object Utils extends Logging { var firstUserFile = "" var firstUserLine = 0 var finished = false + var firstUserClass = "" for (el <- trace) { if (!finished) { @@ -554,13 +555,18 @@ private object Utils extends Logging { else { firstUserLine = el.getLineNumber firstUserFile = el.getFileName + firstUserClass = el.getClassName finished = true } } } - "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) + (lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } + def formatSparkCallSite = { + val callSiteInfo = getCallSiteInfo + "%s at %s:%s".format(callSiteInfo._1, callSiteInfo._2, callSiteInfo._3) + } /** * Try to find a free port to bind to on the local host. This should ideally never be needed, * except that, unfortunately, some of the networking libraries we currently rely on (e.g. 
Spray) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 890938d48b..8bebfafce4 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -104,6 +104,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() task.metrics.foreach{ m => + m.hostname = Utils.localHostName m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt } diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index a7c56c2371..26e8029365 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -1,6 +1,11 @@ package spark.executor class TaskMetrics extends Serializable { + /** + * host's name the task runs on + */ + var hostname: String = _ + /** * Time taken on the executor to deserialize this task */ @@ -33,10 +38,15 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { + /** + * Time when shuffle finishs + */ + var shuffleFinishTime: Long = _ + /** * Total number of blocks fetched in a shuffle (remote or local) */ - var totalBlocksFetched : Int = _ + var totalBlocksFetched: Int = _ /** * Number of remote blocks fetched in a shuffle diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7feeb97542..43dd7d6534 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -289,6 +289,7 @@ class DAGScheduler( val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() + sparkListeners.foreach(_.onJobStart(job, properties)) logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -311,6 +312,7 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => + sparkListeners.foreach(_.onTaskEnd(completion)) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -321,6 +323,8 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) + val JobCancelEvent = new SparkListenerJobCancelled("SPARKCONTEXT_SHUTDOWN") + sparkListeners.foreach(_.onJobEnd(job, JobCancelEvent)) } return true } @@ -468,6 +472,7 @@ class DAGScheduler( } } if (tasks.size > 0) { + sparkListeners.foreach(_.onStageSubmitted(stage, "TASKS_SIZE=" + tasks.size)) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) @@ -522,6 +527,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + sparkListeners.foreach(_.onJobEnd(job, SparkListenerJobSuccess)) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -665,6 +671,8 @@ class DAGScheduler( job.listener.jobFailed(new SparkException("Job failed: " + reason)) activeJobs -= job resultStageToJob -= resultStage + val jobFailedEvent = new SparkListenerJobFailed(failedStage) + 
sparkListeners.foreach(_.onJobEnd(job, jobFailedEvent)) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala new file mode 100644 index 0000000000..f87acfd0b6 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -0,0 +1,317 @@ +package spark.scheduler + +import java.io.PrintWriter +import java.io.File +import java.io.FileNotFoundException +import java.text.SimpleDateFormat +import java.util.{Date, Properties} +import java.util.concurrent.LinkedBlockingQueue +import scala.collection.mutable.{Map, HashMap, ListBuffer} +import scala.io.Source +import spark._ +import spark.executor.TaskMetrics +import spark.scheduler.cluster.TaskInfo + +// used to record runtime information for each job, including RDD graph +// tasks' start/stop shuffle information and information from outside + +sealed trait JobLoggerEvent +case class JobLoggerOnJobStart(job: ActiveJob, properties: Properties) extends JobLoggerEvent +case class JobLoggerOnStageSubmitted(stage: Stage, info: String) extends JobLoggerEvent +case class JobLoggerOnStageCompleted(stageCompleted: StageCompleted) extends JobLoggerEvent +case class JobLoggerOnJobEnd(job: ActiveJob, event: SparkListenerEvents) extends JobLoggerEvent +case class JobLoggerOnTaskEnd(event: CompletionEvent) extends JobLoggerEvent + +class JobLogger(val logDirName: String) extends SparkListener with Logging { + private val logDir = + if (System.getenv("SPARK_LOG_DIR") != null) + System.getenv("SPARK_LOG_DIR") + else + "/tmp/spark" + private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] + private val stageIDToJobID = new HashMap[Int, Int] + private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] + private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + private val eventQueue = new LinkedBlockingQueue[JobLoggerEvent] + + createLogDir() + def this() = this(String.valueOf(System.currentTimeMillis())) + + def getLogDir = logDir + def getJobIDtoPrintWriter = jobIDToPrintWriter + def getStageIDToJobID = stageIDToJobID + def getJobIDToStages = jobIDToStages + def getEventQueue = eventQueue + + new Thread("JobLogger") { + setDaemon(true) + override def run() { + while (true) { + val event = eventQueue.take + if (event != null) { + logDebug("Got event of type " + event.getClass.getName) + event match { + case JobLoggerOnJobStart(job, info) => + processJobStartEvent(job, info) + case JobLoggerOnStageSubmitted(stage, info) => + processStageSubmittedEvent(stage, info) + case JobLoggerOnStageCompleted(stageCompleted) => + processStageCompletedEvent(stageCompleted) + case JobLoggerOnJobEnd(job, event) => + processJobEndEvent(job, event) + case JobLoggerOnTaskEnd(event) => + processTaskEndEvent(event) + case _ => + } + } + } + } + }.start() + + //create a folder for log files, the folder's name is the creation time of the jobLogger + protected def createLogDir() { + val dir = new File(logDir + "/" + logDirName + "/") + if (dir.exists()) { + return + } + if (dir.mkdirs() == false) { + logError("create log directory error:" + logDir + "/" + logDirName + "/") + } + } + + // create a log file for one job, the file name is the jobID + protected def createLogWriter(jobID: Int) { + try{ + val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) + jobIDToPrintWriter += (jobID -> fileWriter) + } catch { + case e: FileNotFoundException => 
e.printStackTrace() + } + } + + // close log file for one job, and clean the stage relationship in stageIDToJobID + protected def closeLogWriter(jobID: Int) = + jobIDToPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + jobIDToStages.get(jobID).foreach(_.foreach{ stage => + stageIDToJobID -= stage.id + }) + jobIDToPrintWriter -= jobID + jobIDToStages -= jobID + } + + // write log information to log file, withTime parameter controls whether to recored + // time stamp for the information + protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { + var writeInfo = info + if (withTime) { + val date = new Date(System.currentTimeMillis()) + writeInfo = DATE_FORMAT.format(date) + ": " +info + } + jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo)) + } + + protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) = + stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime)) + + protected def buildJobDep(jobID: Int, stage: Stage) { + if (stage.priority == jobID) { + jobIDToStages.get(jobID) match { + case Some(stageList) => stageList += stage + case None => val stageList = new ListBuffer[Stage] + stageList += stage + jobIDToStages += (jobID -> stageList) + } + stageIDToJobID += (stage.id -> jobID) + stage.parents.foreach(buildJobDep(jobID, _)) + } + } + + protected def recordStageDep(jobID: Int) { + def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = { + var rddList = new ListBuffer[RDD[_]] + rddList += rdd + rdd.dependencies.foreach{ dep => dep match { + case shufDep: ShuffleDependency[_,_] => + case _ => rddList ++= getRddsInStage(dep.rdd) + } + } + rddList + } + jobIDToStages.get(jobID).foreach {_.foreach { stage => + var depRddDesc: String = "" + getRddsInStage(stage.rdd).foreach { rdd => + depRddDesc += rdd.id + "," + } + var depStageDesc: String = "" + stage.parents.foreach { stage => + depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")" + } + jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + + depRddDesc.substring(0, depRddDesc.length - 1) + ")" + + " STAGE_DEP=" + depStageDesc, false) + } + } + } + + // generate indents and convert to String + protected def indentString(indent: Int) = { + val sb = new StringBuilder() + for (i <- 1 to indent) { + sb.append(" ") + } + sb.toString() + } + + protected def getRddName(rdd: RDD[_]) = { + var rddName = rdd.getClass.getName + if (rdd.name != null) { + rddName = rdd.name + } + rddName + } + + protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { + val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")" + jobLogInfo(jobID, indentString(indent) + rddInfo, false) + rdd.dependencies.foreach{ dep => dep match { + case shufDep: ShuffleDependency[_,_] => + val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId + jobLogInfo(jobID, indentString(indent + 1) + depInfo, false) + case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1) + } + } + } + + protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) { + var stageInfo: String = "" + if (stage.isShuffleMap) { + stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + + stage.shuffleDep.get.shuffleId + }else{ + stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE" + } + if (stage.priority == jobID) { + jobLogInfo(jobID, indentString(indent) + stageInfo, false) + recordRddInStageGraph(jobID, stage.rdd, indent) + stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2)) + } else + jobLogInfo(jobID, 
indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false) + } + + // record task metrics into job log files + protected def recordTaskMetrics(stageID: Int, status: String, + taskInfo: TaskInfo, taskMetrics: TaskMetrics) { + val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + + " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname + val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val readMetrics = + taskMetrics.shuffleReadMetrics match { + case Some(metrics) => + " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + + " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + + " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + + " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + case None => "" + } + val writeMetrics = + taskMetrics.shuffleWriteMetrics match { + case Some(metrics) => + " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case None => "" + } + stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) + } + + override def onStageSubmitted(stage: Stage, info: String = "") { + eventQueue.put(JobLoggerOnStageSubmitted(stage, info)) + } + + protected def processStageSubmittedEvent(stage: Stage, info: String) { + stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED " + info) + } + + override def onStageCompleted(stageCompleted: StageCompleted) { + eventQueue.put(JobLoggerOnStageCompleted(stageCompleted)) + } + + protected def processStageCompletedEvent(stageCompleted: StageCompleted) { + stageLogInfo(stageCompleted.stageInfo.stage.id, "STAGE_ID=" + + stageCompleted.stageInfo.stage.id + " STATUS=COMPLETED") + + } + + override def onTaskEnd(event: CompletionEvent) { + eventQueue.put(JobLoggerOnTaskEnd(event)) + } + + protected def processTaskEndEvent(event: CompletionEvent) { + var taskStatus = "" + event.task match { + case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" + case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" + } + event.reason match { + case Success => taskStatus += " STATUS=SUCCESS" + recordTaskMetrics(event.task.stageId, taskStatus, event.taskInfo, event.taskMetrics) + case Resubmitted => + taskStatus += " STATUS=RESUBMITTED TID=" + event.taskInfo.taskId + + " STAGE_ID=" + event.task.stageId + stageLogInfo(event.task.stageId, taskStatus) + case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + taskStatus += " STATUS=FETCHFAILED TID=" + event.taskInfo.taskId + " STAGE_ID=" + + event.task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + + mapId + " REDUCE_ID=" + reduceId + stageLogInfo(event.task.stageId, taskStatus) + case OtherFailure(message) => + taskStatus += " STATUS=FAILURE TID=" + event.taskInfo.taskId + + " STAGE_ID=" + event.task.stageId + " INFO=" + message + stageLogInfo(event.task.stageId, taskStatus) + case _ => + } + } + + override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { + eventQueue.put(JobLoggerOnJobEnd(job, event)) + } + + protected def processJobEndEvent(job: ActiveJob, event: SparkListenerEvents) { + var info = "JOB_ID=" + job.runId + " STATUS=" + var validEvent = true + event match { + case SparkListenerJobSuccess => info += "SUCCESS" + case SparkListenerJobFailed(failedStage) => + info += "FAILED 
REASON=STAGE_FAILED FAILED_STAGE_ID=" + failedStage.id + case SparkListenerJobCancelled(reason) => info += "CANCELLED REASON=" + reason + case _ => validEvent = false + } + if (validEvent) { + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) + } + } + + protected def recordJobProperties(jobID: Int, properties: Properties) { + if(properties != null) { + val annotation = properties.getProperty("spark.job.annotation", "") + jobLogInfo(jobID, annotation, false) + } + } + + override def onJobStart(job: ActiveJob, properties: Properties = null) { + eventQueue.put(JobLoggerOnJobStart(job, properties)) + } + + protected def processJobStartEvent(job: ActiveJob, properties: Properties) { + createLogWriter(job.runId) + recordJobProperties(job.runId, properties) + buildJobDep(job.runId, job.finalStage) + recordStageDep(job.runId) + recordStageDepGraph(job.runId, job.finalStage) + jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED") + } +} diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index a65140b145..9cf7f3ffc0 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -1,27 +1,54 @@ package spark.scheduler +import java.util.Properties import spark.scheduler.cluster.TaskInfo import spark.util.Distribution -import spark.{Utils, Logging} +import spark.{Utils, Logging, SparkContext, TaskEndReason} import spark.executor.TaskMetrics trait SparkListener { /** * called when a stage is completed, with information on the completed stage */ - def onStageCompleted(stageCompleted: StageCompleted) + def onStageCompleted(stageCompleted: StageCompleted) { } + + /** + * called when a stage is submitted + */ + def onStageSubmitted(stage: Stage, info: String = "") { } + + /** + * called when a task ends + */ + def onTaskEnd(event: CompletionEvent) { } + + /** + * called when a job starts + */ + def onJobStart(job: ActiveJob, properties: Properties = null) { } + + /** + * called when a job ends + */ + def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { } + } sealed trait SparkListenerEvents case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents +case object SparkListenerJobSuccess extends SparkListenerEvents + +case class SparkListenerJobFailed(failedStage: Stage) extends SparkListenerEvents + +case class SparkListenerJobCancelled(reason: String) extends SparkListenerEvents /** * Simple SparkListener that logs a few summary statistics when each stage completes */ class StatsReportListener extends SparkListener with Logging { - def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: StageCompleted) { import spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stageInfo) diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala new file mode 100644 index 0000000000..34fd8b995e --- /dev/null +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -0,0 +1,105 @@ +package spark.scheduler + +import java.util.Properties +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable +import spark._ +import spark.SparkContext._ + + +class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("inner method") 
{ + sc = new SparkContext("local", "joblogger") + val joblogger = new JobLogger { + def createLogWriterTest(jobID: Int) = createLogWriter(jobID) + def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) + def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) + def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) + } + type MyRDD = RDD[(Int, Int)] + def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]] + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + } + } + val jobID = 5 + val parentRdd = makeRdd(4, Nil) + val shuffleDep = new ShuffleDependency(parentRdd, null) + val rootRdd = makeRdd(4, List(shuffleDep)) + val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID) + val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) + + joblogger.onStageSubmitted(rootStage) + joblogger.getEventQueue.size should be (1) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + parentRdd.setName("MyRDD") + joblogger.getRddNameTest(parentRdd) should be ("MyRDD") + joblogger.createLogWriterTest(jobID) + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.buildJobDepTest(jobID, rootStage) + joblogger.getJobIDToStages.get(jobID).get.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) + joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) + joblogger.closeLogWriterTest(jobID) + joblogger.getStageIDToJobID.size should be (0) + joblogger.getJobIDToStages.size should be (0) + joblogger.getJobIDtoPrintWriter.size should be (0) + } + + test("inner variables") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + override protected def closeLogWriter(jobID: Int) = + getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + } + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.getLogDir should be ("/tmp/spark") + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.getStageIDToJobID.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(0)) + joblogger.getStageIDToJobID.get(1) should be (Some(0)) + joblogger.getJobIDToStages.size should be (1) + } + + + test("interface functions") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + var onTaskEndCount = 0 + var onJobEndCount = 0 + var onJobStartCount = 0 + var onStageCompletedCount = 0 + var onStageSubmittedCount = 0 + override def onTaskEnd(event: CompletionEvent) = onTaskEndCount += 1 + override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) = onJobEndCount += 1 + override def onJobStart(job: ActiveJob, properties: Properties) = onJobStartCount += 1 + override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageSubmitted(stage: Stage, info: String = "") = onStageSubmittedCount += 1 + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.onJobStartCount should be (1) + joblogger.onJobEndCount should be (1) + joblogger.onTaskEndCount should be (8) + 
joblogger.onStageSubmittedCount should be (2) + joblogger.onStageCompletedCount should be (2) + } +} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala index 42a87d8b90..48aa67c543 100644 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc class SaveStageInfo extends SparkListener { val stageInfos = mutable.Buffer[StageInfo]() - def onStageCompleted(stage: StageCompleted) { + override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stageInfo } } -- cgit v1.2.3 From 4fd86e0e10149ad1803831a308a056c7105cbe67 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Sat, 8 Jun 2013 15:45:47 +0800 Subject: delete test code for joblogger in SparkContext --- core/src/main/scala/spark/SparkContext.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b67a2066c8..70a9d7698c 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -48,7 +48,6 @@ import spark.scheduler.local.LocalScheduler import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, TimeStampedHashMap} -import spark.scheduler.JobLogger /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -510,7 +509,7 @@ class SparkContext( def addSparkListener(listener: SparkListener) { dagScheduler.sparkListeners += listener } - addSparkListener(new JobLogger) + /** * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. 
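Since this commit removes the automatic addSparkListener(new JobLogger) call, an application that still wants the per-job log files has to register the listener itself. A minimal sketch (illustrative only, not part of the patch; assumes a locally created SparkContext):

    import spark.SparkContext
    import spark.SparkContext._
    import spark.scheduler.JobLogger

    val sc = new SparkContext("local[2]", "joblogger-example")
    // Attach the listener explicitly; log files go under SPARK_LOG_DIR (default /tmp/spark).
    sc.addSparkListener(new JobLogger())
    sc.parallelize(1 to 100, 4).map(i => (i % 10, i)).reduceByKey(_ + _).collect()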
-- cgit v1.2.3 From ade822011d44bd43e9ac78c1d29ec924a1f6e8e7 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Sat, 8 Jun 2013 16:26:45 +0800 Subject: not check return value of eventQueue.take --- .../src/main/scala/spark/scheduler/JobLogger.scala | 28 ++++++++++------------ 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index f87acfd0b6..46b9fa974b 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -48,21 +48,19 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { override def run() { while (true) { val event = eventQueue.take - if (event != null) { - logDebug("Got event of type " + event.getClass.getName) - event match { - case JobLoggerOnJobStart(job, info) => - processJobStartEvent(job, info) - case JobLoggerOnStageSubmitted(stage, info) => - processStageSubmittedEvent(stage, info) - case JobLoggerOnStageCompleted(stageCompleted) => - processStageCompletedEvent(stageCompleted) - case JobLoggerOnJobEnd(job, event) => - processJobEndEvent(job, event) - case JobLoggerOnTaskEnd(event) => - processTaskEndEvent(event) - case _ => - } + logDebug("Got event of type " + event.getClass.getName) + event match { + case JobLoggerOnJobStart(job, info) => + processJobStartEvent(job, info) + case JobLoggerOnStageSubmitted(stage, info) => + processStageSubmittedEvent(stage, info) + case JobLoggerOnStageCompleted(stageCompleted) => + processStageCompletedEvent(stageCompleted) + case JobLoggerOnJobEnd(job, event) => + processJobEndEvent(job, event) + case JobLoggerOnTaskEnd(event) => + processTaskEndEvent(event) + case _ => } } } -- cgit v1.2.3 From d1bbcebae580220076ceaa65f84dcf984ab51a16 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 8 Jun 2013 16:58:42 -0700 Subject: Adding compression to Hadoop save functions --- core/src/main/scala/spark/PairRDDFunctions.scala | 39 +++++++++++++++++- core/src/main/scala/spark/RDD.scala | 9 ++++ .../scala/spark/SequenceFileRDDFunctions.scala | 15 ++++--- core/src/test/scala/spark/FileSuite.scala | 48 ++++++++++++++++++++++ 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 2b0e697337..9bf1227d65 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -10,6 +10,8 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.FileOutputCommitter import org.apache.hadoop.mapred.FileOutputFormat import org.apache.hadoop.mapred.HadoopWriter @@ -515,6 +517,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress the result with the + * supplied codec. 
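 *
 * Illustrative usage of the codec overloads added in this commit (editorial sketch, not part of
 * the original patch; assumes org.apache.hadoop.mapred.TextOutputFormat and
 * org.apache.hadoop.io.compress.GzipCodec on the classpath, a pair RDD named pairs of type
 * RDD[(String, Int)], and a plain RDD named lines):
 *   pairs.saveAsHadoopFile[TextOutputFormat[String, Int]]("hdfs:///out/pairs", classOf[GzipCodec])
 *   lines.saveAsTextFile("hdfs:///out/text", classOf[GzipCodec])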
+ */ + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + } + /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. @@ -574,6 +586,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( jobCommitter.cleanupJob(jobTaskContext) } + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress with the supplied codec. + */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + codec: Class[_ <: CompressionCodec]) { + saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, + new JobConf(self.context.hadoopConfiguration), Some(codec)) + } + /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. @@ -583,11 +609,22 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration), + codec: Option[Class[_ <: CompressionCodec]] = None) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug conf.set("mapred.output.format.class", outputFormatClass.getName) + codec match { + case Some(c) => { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + } + case _ => + } conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) saveAsHadoopDataset(conf) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..e5995bea22 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -7,6 +7,7 @@ import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextOutputFormat @@ -730,6 +731,14 @@ abstract class RDD[T: ClassManifest]( .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) } + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) + } + /** * Save this RDD as a SequenceFile of serialized objects. 
*/ diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index 518034e07b..2911f9036e 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.OutputCommitter import org.apache.hadoop.mapred.FileOutputCommitter +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.Writable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.BytesWritable @@ -62,7 +63,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported * file system. */ - def saveAsSequenceFile(path: String) { + def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u val keyClass = getWritableClass[K] @@ -72,14 +73,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] + val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format) + self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } } } diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 91b48c7456..a5d2028591 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -7,6 +7,8 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io._ +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} + import SparkContext._ @@ -26,6 +28,29 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) } + test("text files (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize("a" * 10000, 1) + data.saveAsTextFile(normalDir) + data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) + + val normalFile = new File(normalDir, 
"part-00000") + val normalContent = sc.textFile(normalDir).collect + assert(normalContent === Array.fill(10000)("a")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.textFile(compressedOutputDir).collect + assert(compressedContent === Array.fill(10000)("a")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFiles") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() @@ -37,6 +62,29 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } + test("SequenceFile (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) + data.saveAsSequenceFile(normalDir) + data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.sequenceFile[String, String](normalDir).collect + assert(normalContent === Array.fill(100)("abc", "abc")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect + assert(compressedContent === Array.fill(100)("abc", "abc")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFile with writable key") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() -- cgit v1.2.3 From 083a3485abdcda5913c2186c4a7930ac07b061c4 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Jun 2013 11:49:33 -0700 Subject: Clean extra whitespace --- core/src/test/scala/spark/FileSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index a5d2028591..e61ff7793d 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -50,7 +50,6 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(compressedFile.length < normalFile.length) } - test("SequenceFiles") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() @@ -84,7 +83,6 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(compressedFile.length < normalFile.length) } - test("SequenceFile with writable key") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() -- cgit v1.2.3 From df592192e736edca9e382a7f92e15bead390ef65 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Jun 2013 18:09:24 -0700 Subject: Monads FTW --- core/src/main/scala/spark/PairRDDFunctions.scala | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 9bf1227d65..15593db0d9 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -615,15 +615,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug conf.set("mapred.output.format.class", 
outputFormatClass.getName) - codec match { - case Some(c) => { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) - } - case _ => + for (c <- codec) { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) -- cgit v1.2.3 From ef14dc2e7736732932d4edceb3be8d81ba9f8bc7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Jun 2013 18:09:46 -0700 Subject: Adding Java-API version of compression codec --- .../main/scala/spark/api/java/JavaPairRDD.scala | 11 ++++++ .../main/scala/spark/api/java/JavaRDDLike.scala | 8 ++++ core/src/test/scala/spark/JavaAPISuite.java | 46 ++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 30084df4e2..76051597b6 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -6,6 +6,7 @@ import java.util.Comparator import scala.Tuple2 import scala.collection.JavaConversions._ +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} @@ -459,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) } + /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + codec: Class[_ <: CompressionCodec]) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) + } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( path: String, diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 9b74d1226f..76b14e2e04 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -4,6 +4,7 @@ import java.util.{List => JList} import scala.Tuple2 import scala.collection.JavaConversions._ +import org.apache.hadoop.io.compress.CompressionCodec import spark.{SparkContext, Partition, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} @@ -310,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) = + rdd.saveAsTextFile(path, codec) + /** * Save this RDD as a SequenceFile of serialized objects. 
*/ diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 93bb69b41c..6caa85119a 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -8,6 +8,7 @@ import java.util.*; import scala.Tuple2; import com.google.common.base.Charsets; +import org.apache.hadoop.io.compress.DefaultCodec; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -473,6 +474,19 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, readRDD.collect()); } + @Test + public void textFilesCompressed() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir, DefaultCodec.class); + + // Try reading it in as a text file RDD + List expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + @Test public void sequenceFile() { File tempDir = Files.createTempDir(); @@ -619,6 +633,38 @@ public class JavaAPISuite implements Serializable { }).collect().toString()); } + @Test + public void hadoopFileCompressed() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, + DefaultCodec.class); + + System.out.println(outputDir); + JavaPairRDD output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + @Test public void zip() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); -- cgit v1.2.3 From 190ec617997d621c11ed1aab662a6e3a06815d2f Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Mon, 10 Jun 2013 15:27:02 +0800 Subject: change code style and debug info --- core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 8 +++----- .../main/scala/spark/scheduler/local/LocalTaskSetManager.scala | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 69dacfc2bd..93d4318b29 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -34,8 +34,7 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I } def launchTask(tasks : Seq[TaskDescription]) { - for (task <- tasks) - { + for (task <- tasks) { freeCores -= 1 localScheduler.threadPool.submit(new Runnable { def run() { @@ -85,8 +84,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } schedulableBuilder.buildPools() - localActor = env.actorSystem.actorOf( - Props(new LocalActor(this, threads)), "Test") + localActor = env.actorSystem.actorOf(Props(new 
LocalActor(this, threads)), "Test") } override def setListener(listener: TaskSchedulerListener) { @@ -109,7 +107,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val tasks = new ArrayBuffer[TaskDescription](freeCores) val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() for (manager <- sortedTaskSetQueue) { - logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) } var launchTask = false diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala index f2e07d162a..70b69bb26f 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -91,7 +91,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas if (availableCpus > 0 && numFinished < numTasks) { findTask() match { case Some(index) => - logInfo(taskSet.tasks(index).toString) val taskId = sched.attemptId.getAndIncrement() val task = taskSet.tasks(index) val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) -- cgit v1.2.3 From fd6148c8b20bc051786ff574d3b8f3b5e79b391a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 10 Jun 2013 10:27:25 -0700 Subject: Removing print statement --- core/src/test/scala/spark/JavaAPISuite.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 6caa85119a..d306124fca 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -652,7 +652,6 @@ public class JavaAPISuite implements Serializable { }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); - System.out.println(outputDir); JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); -- cgit v1.2.3 From dc4073654b1707f115de30088938f6e53efda0ba Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 11 Jun 2013 00:08:02 -0400 Subject: Revert "Fix start-slave not passing instance number to spark-daemon." This reverts commit a674d67c0aebb940e3b816e2307206115baec175. 
--- bin/start-slave.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/start-slave.sh b/bin/start-slave.sh index dfcbc6981b..26b5b9d462 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -12,4 +12,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.worker.Worker 1 "$@" +"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@" -- cgit v1.2.3 From db5bca08ff00565732946a9c0a0244a9f7021d82 Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Wed, 12 Jun 2013 10:54:16 -0700 Subject: add a new top K method to RDD using a bounded priority queue --- core/src/main/scala/spark/RDD.scala | 24 +++++++++++ .../scala/spark/util/BoundedPriorityQueue.scala | 48 ++++++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 19 +++++++++ 3 files changed, 91 insertions(+) create mode 100644 core/src/main/scala/spark/util/BoundedPriorityQueue.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..ec5e5e2433 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -35,6 +35,7 @@ import spark.rdd.ZippedPartitionsRDD2 import spark.rdd.ZippedPartitionsRDD3 import spark.rdd.ZippedPartitionsRDD4 import spark.storage.StorageLevel +import spark.util.BoundedPriorityQueue import SparkContext._ @@ -722,6 +723,29 @@ abstract class RDD[T: ClassManifest]( case _ => throw new UnsupportedOperationException("empty collection") } + /** + * Returns the top K elements from this RDD as defined by + * the specified implicit Ordering[T]. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { + val topK = mapPartitions { items => + val queue = new BoundedPriorityQueue[T](num) + queue ++= items + Iterator(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + } + + val builder = Array.newBuilder[T] + builder.sizeHint(topK.size) + builder ++= topK + builder.result() + } + /** * Save this RDD as a text file, using string representations of elements. */ diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..53ee95a02e --- /dev/null +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,48 @@ +package spark.util + +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable + +/** + * Bounded priority queue. This class modifies the original PriorityQueue's + * add/offer methods such that only the top K elements are retained. The top + * K elements are defined by an implicit Ordering[A]. 
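 *
 * Illustrative use of the RDD.top method that builds on this queue (editorial sketch, not part
 * of the original patch; assumes an existing SparkContext named sc):
 *   val ints = sc.parallelize(Seq(8, 1, 5, 9, 3))
 *   ints.top(2)                        // the two largest values, 9 and 8
 *   ints.top(2)(Ordering[Int].reverse) // with a reversed ordering: the two smallest, 1 and 3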
+ */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A]) + extends JPriorityQueue[A](maxSize, ord) with Growable[A] { + + override def offer(a: A): Boolean = { + if (size < maxSize) super.offer(a) + else maybeReplaceLowest(a) + } + + override def add(a: A): Boolean = offer(a) + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach(add) + this + } + + override def +=(elem: A): this.type = { + add(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = peek() + if (head != null && ord.gt(a, head)) { + poll() + super.offer(a) + } else false + } +} + +object BoundedPriorityQueue { + import scala.collection.JavaConverters._ + implicit def asIterable[A](queue: BoundedPriorityQueue[A]): Iterable[A] = queue.asScala +} + diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 3f69e99780..67f3332d44 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(sample.size === checkSample.size) for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) } + + test("top with predefined ordering") { + sc = new SparkContext("local", "test") + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK.sorted === nums.sorted.takeRight(5)) + } + + test("top with custom ordering") { + sc = new SparkContext("local", "test") + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } } -- cgit v1.2.3 From 3f96c6f27b08039fb7b8d295f5de2083544e979f Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 12 Jun 2013 17:20:05 -0700 Subject: Fixed jvmArgs in maven build. 
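The underlying problem is that each jvmArg element is passed to the forked JVM as a single token, so an -XX option and its value have to form one argument (for example -XX:MaxPermSize=512m) rather than two separate entries; the change below joins them with '=' and also raises the PermGen default from 0m to 64m.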
--- pom.xml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index ce77ba37c6..c893ec755e 100644 --- a/pom.xml +++ b/pom.xml @@ -60,7 +60,7 @@ 4.1.2 1.2.17 - 0m + 64m 512m @@ -395,10 +395,8 @@ -Xms64m -Xmx1024m - -XX:PermSize - ${PermGen} - -XX:MaxPermSize - ${MaxPermGen} + -XX:PermSize=${PermGen} + -XX:MaxPermSize=${MaxPermGen} -source -- cgit v1.2.3 From 967a6a699da7da007f51e59d085a357da5ec14da Mon Sep 17 00:00:00 2001 From: Mingfei Date: Thu, 13 Jun 2013 14:36:07 +0800 Subject: modify sparklister function interface according to comments --- .../main/scala/spark/scheduler/DAGScheduler.scala | 15 ++-- .../src/main/scala/spark/scheduler/JobLogger.scala | 89 +++++++++++----------- .../main/scala/spark/scheduler/SparkListener.scala | 38 +++++---- .../scala/spark/scheduler/JobLoggerSuite.scala | 10 +-- 4 files changed, 79 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 43dd7d6534..e281e5a8db 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -289,7 +289,7 @@ class DAGScheduler( val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - sparkListeners.foreach(_.onJobStart(job, properties)) + sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -312,7 +312,7 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => - sparkListeners.foreach(_.onTaskEnd(completion)) + sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion))) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -323,8 +323,8 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - val JobCancelEvent = new SparkListenerJobCancelled("SPARKCONTEXT_SHUTDOWN") - sparkListeners.foreach(_.onJobEnd(job, JobCancelEvent)) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobCancelled(job, + "SPARKCONTEXT_SHUTDOWN"))) } return true } @@ -472,7 +472,7 @@ class DAGScheduler( } } if (tasks.size > 0) { - sparkListeners.foreach(_.onStageSubmitted(stage, "TASKS_SIZE=" + tasks.size)) + sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size))) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) @@ -527,7 +527,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - sparkListeners.foreach(_.onJobEnd(job, SparkListenerJobSuccess)) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobSuccess(job))) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -671,8 +671,7 @@ class DAGScheduler( job.listener.jobFailed(new SparkException("Job failed: " + reason)) activeJobs -= job resultStageToJob -= resultStage - val jobFailedEvent = new SparkListenerJobFailed(failedStage) - sparkListeners.foreach(_.onJobEnd(job, jobFailedEvent)) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobFailed(job, failedStage))) } if (dependentStages.isEmpty) { 
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index 46b9fa974b..002c5826cb 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -15,13 +15,6 @@ import spark.scheduler.cluster.TaskInfo // used to record runtime information for each job, including RDD graph // tasks' start/stop shuffle information and information from outside -sealed trait JobLoggerEvent -case class JobLoggerOnJobStart(job: ActiveJob, properties: Properties) extends JobLoggerEvent -case class JobLoggerOnStageSubmitted(stage: Stage, info: String) extends JobLoggerEvent -case class JobLoggerOnStageCompleted(stageCompleted: StageCompleted) extends JobLoggerEvent -case class JobLoggerOnJobEnd(job: ActiveJob, event: SparkListenerEvents) extends JobLoggerEvent -case class JobLoggerOnTaskEnd(event: CompletionEvent) extends JobLoggerEvent - class JobLogger(val logDirName: String) extends SparkListener with Logging { private val logDir = if (System.getenv("SPARK_LOG_DIR") != null) @@ -32,7 +25,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { private val stageIDToJobID = new HashMap[Int, Int] private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - private val eventQueue = new LinkedBlockingQueue[JobLoggerEvent] + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] createLogDir() def this() = this(String.valueOf(System.currentTimeMillis())) @@ -50,15 +43,19 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { val event = eventQueue.take logDebug("Got event of type " + event.getClass.getName) event match { - case JobLoggerOnJobStart(job, info) => - processJobStartEvent(job, info) - case JobLoggerOnStageSubmitted(stage, info) => - processStageSubmittedEvent(stage, info) - case JobLoggerOnStageCompleted(stageCompleted) => - processStageCompletedEvent(stageCompleted) - case JobLoggerOnJobEnd(job, event) => - processJobEndEvent(job, event) - case JobLoggerOnTaskEnd(event) => + case SparkListenerJobStart(job, properties) => + processJobStartEvent(job, properties) + case SparkListenerStageSubmitted(stage, taskSize) => + processStageSubmittedEvent(stage, taskSize) + case StageCompleted(stageInfo) => + processStageCompletedEvent(stageInfo) + case SparkListenerJobSuccess(job) => + processJobEndEvent(job) + case SparkListenerJobFailed(job, failedStage) => + processJobEndEvent(job, failedStage) + case SparkListenerJobCancelled(job, reason) => + processJobEndEvent(job, reason) + case SparkListenerTaskEnd(event) => processTaskEndEvent(event) case _ => } @@ -225,26 +222,26 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) } - override def onStageSubmitted(stage: Stage, info: String = "") { - eventQueue.put(JobLoggerOnStageSubmitted(stage, info)) + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + eventQueue.put(stageSubmitted) } - protected def processStageSubmittedEvent(stage: Stage, info: String) { - stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED " + info) + protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) { + stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + 
" TASK_SIZE=" + taskSize) } override def onStageCompleted(stageCompleted: StageCompleted) { - eventQueue.put(JobLoggerOnStageCompleted(stageCompleted)) + eventQueue.put(stageCompleted) } - protected def processStageCompletedEvent(stageCompleted: StageCompleted) { - stageLogInfo(stageCompleted.stageInfo.stage.id, "STAGE_ID=" + - stageCompleted.stageInfo.stage.id + " STATUS=COMPLETED") + protected def processStageCompletedEvent(stageInfo: StageInfo) { + stageLogInfo(stageInfo.stage.id, "STAGE_ID=" + + stageInfo.stage.id + " STATUS=COMPLETED") } - override def onTaskEnd(event: CompletionEvent) { - eventQueue.put(JobLoggerOnTaskEnd(event)) + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + eventQueue.put(taskEnd) } protected def processTaskEndEvent(event: CompletionEvent) { @@ -273,24 +270,26 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { - eventQueue.put(JobLoggerOnJobEnd(job, event)) + override def onJobEnd(jobEnd: SparkListenerEvents) { + eventQueue.put(jobEnd) } - protected def processJobEndEvent(job: ActiveJob, event: SparkListenerEvents) { - var info = "JOB_ID=" + job.runId + " STATUS=" - var validEvent = true - event match { - case SparkListenerJobSuccess => info += "SUCCESS" - case SparkListenerJobFailed(failedStage) => - info += "FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" + failedStage.id - case SparkListenerJobCancelled(reason) => info += "CANCELLED REASON=" + reason - case _ => validEvent = false - } - if (validEvent) { - jobLogInfo(job.runId, info) - closeLogWriter(job.runId) - } + protected def processJobEndEvent(job: ActiveJob) { + val info = "JOB_ID=" + job.runId + " STATUS=SUCCESS" + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) + } + + protected def processJobEndEvent(job: ActiveJob, failedStage: Stage) { + val info = "JOB_ID=" + job.runId + " STATUS=FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" + + failedStage.id + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) + } + protected def processJobEndEvent(job: ActiveJob, reason: String) { + var info = "JOB_ID=" + job.runId + " STATUS=CANCELLED REASON=" + reason + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) } protected def recordJobProperties(jobID: Int, properties: Properties) { @@ -300,8 +299,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - override def onJobStart(job: ActiveJob, properties: Properties = null) { - eventQueue.put(JobLoggerOnJobStart(job, properties)) + override def onJobStart(jobStart: SparkListenerJobStart) { + eventQueue.put(jobStart) } protected def processJobStartEvent(job: ActiveJob, properties: Properties) { diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index 9cf7f3ffc0..9265261dc1 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -6,6 +6,24 @@ import spark.util.Distribution import spark.{Utils, Logging, SparkContext, TaskEndReason} import spark.executor.TaskMetrics + +sealed trait SparkListenerEvents + +case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents + +case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents + +case class SparkListenerTaskEnd(event: CompletionEvent) extends SparkListenerEvents + +case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) + extends 
SparkListenerEvents + +case class SparkListenerJobSuccess(job: ActiveJob) extends SparkListenerEvents + +case class SparkListenerJobFailed(job: ActiveJob, failedStage: Stage) extends SparkListenerEvents + +case class SparkListenerJobCancelled(job: ActiveJob, reason: String) extends SparkListenerEvents + trait SparkListener { /** * called when a stage is completed, with information on the completed stage @@ -15,35 +33,25 @@ trait SparkListener { /** * called when a stage is submitted */ - def onStageSubmitted(stage: Stage, info: String = "") { } - + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + /** * called when a task ends */ - def onTaskEnd(event: CompletionEvent) { } + def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } /** * called when a job starts */ - def onJobStart(job: ActiveJob, properties: Properties = null) { } + def onJobStart(jobStart: SparkListenerJobStart) { } /** * called when a job ends */ - def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { } + def onJobEnd(jobEnd: SparkListenerEvents) { } } -sealed trait SparkListenerEvents - -case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents - -case object SparkListenerJobSuccess extends SparkListenerEvents - -case class SparkListenerJobFailed(failedStage: Stage) extends SparkListenerEvents - -case class SparkListenerJobCancelled(reason: String) extends SparkListenerEvents - /** * Simple SparkListener that logs a few summary statistics when each stage completes */ diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala index 34fd8b995e..a654bf3ffd 100644 --- a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -40,7 +40,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID) val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) - joblogger.onStageSubmitted(rootStage) + joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4)) joblogger.getEventQueue.size should be (1) joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) parentRdd.setName("MyRDD") @@ -86,11 +86,11 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers var onJobStartCount = 0 var onStageCompletedCount = 0 var onStageSubmittedCount = 0 - override def onTaskEnd(event: CompletionEvent) = onTaskEndCount += 1 - override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) = onJobEndCount += 1 - override def onJobStart(job: ActiveJob, properties: Properties) = onJobStartCount += 1 + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerEvents) = onJobEndCount += 1 + override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 - override def onStageSubmitted(stage: Stage, info: String = "") = onStageSubmittedCount += 1 + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 } sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } -- cgit v1.2.3 From b5b12823faf62766d880e497c90b44b21f5a433a Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Thu, 13 Jun 2013 14:05:46 +0530 Subject: Fixing the style as per feedback --- 
.../main/scala/spark/examples/CassandraTest.scala | 72 +++++++++++----------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 2cc62b9fe9..0fe1833e83 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -1,9 +1,11 @@ package spark.examples import org.apache.hadoop.mapreduce.Job -import org.apache.cassandra.hadoop.{ColumnFamilyOutputFormat, ConfigHelper, ColumnFamilyInputFormat} +import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat import org.apache.cassandra.thrift._ -import spark.{RDD, SparkContext} +import spark.SparkContext import spark.SparkContext._ import java.nio.ByteBuffer import java.util.SortedMap @@ -12,9 +14,9 @@ import org.apache.cassandra.utils.ByteBufferUtil import scala.collection.JavaConversions._ - /* - * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra support for Hadoop. + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra + * support for Hadoop. * * To run this example, run this file with the following command params - * @@ -26,32 +28,31 @@ import scala.collection.JavaConversions._ * 1. You have already created a keyspace called casDemo and it has a column family named Words * 2. There are column family has a column named "para" which has test content. * - * You can create the content by running the following script at the bottom of this file with cassandra-cli. + * You can create the content by running the following script at the bottom of this file with + * cassandra-cli. 
* */ object CassandraTest { + def main(args: Array[String]) { - //Get a SparkContext + // Get a SparkContext val sc = new SparkContext(args(0), "casDemo") - //Build the job configuration with ConfigHelper provided by Cassandra + // Build the job configuration with ConfigHelper provided by Cassandra val job = new Job() job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) - ConfigHelper.setInputInitialAddress(job.getConfiguration(), args(1)) - - ConfigHelper.setInputRpcPort(job.getConfiguration(), args(2)) - - ConfigHelper.setOutputInitialAddress(job.getConfiguration(), args(1)) - - ConfigHelper.setOutputRpcPort(job.getConfiguration(), args(2)) + val host: String = args(1) + val port: String = args(2) + ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setInputRpcPort(job.getConfiguration(), port) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") - ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") - val predicate = new SlicePredicate() val sliceRange = new SliceRange() sliceRange.setStart(Array.empty[Byte]) @@ -60,11 +61,11 @@ object CassandraTest { ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - //Make a new Hadoop RDD - val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), + // Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD( + job.getConfiguration(), classOf[ColumnFamilyInputFormat], classOf[ByteBuffer], classOf[SortedMap[ByteBuffer, IColumn]]) @@ -76,7 +77,7 @@ object CassandraTest { } } - //Lets get the word count in paras + // Lets get the word count in paras val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) counts.collect().foreach { @@ -95,20 +96,17 @@ object CassandraTest { colCount.setValue(ByteBufferUtil.bytes(count.toLong)) colCount.setTimestamp(System.currentTimeMillis) - val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) - mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(1).column_or_supercolumn.setColumn(colCount) (outputkey, mutations) } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) - } } @@ -117,16 +115,20 @@ create keyspace casDemo; use casDemo; create column family WordCount with comparator = UTF8Type; -update column family WordCount with column_metadata = [{column_name: word, validation_class: UTF8Type}, {column_name: wcount, validation_class: LongType}]; +update column family WordCount with column_metadata = + [{column_name: word, validation_class: UTF8Type}, + {column_name: wcount, validation_class: LongType}]; create column family Words with comparator = UTF8Type; -update column family Words with column_metadata = [{column_name: book, validation_class: UTF8Type}, {column_name: para, validation_class: UTF8Type}]; +update column family Words with column_metadata = + [{column_name: book, validation_class: UTF8Type}, + {column_name: para, validation_class: 
UTF8Type}]; assume Words keys as utf8; set Words['3musk001']['book'] = 'The Three Musketeers'; -set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market town of - Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market + town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to be in as perfect a state of revolution as if the Huguenots had just made a second La Rochelle of it. Many citizens, seeing the women flying toward the High Street, leaving their children crying at the open doors, @@ -136,8 +138,8 @@ set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625 every minute, a compact group, vociferous and full of curiosity.'; set Words['3musk002']['book'] = 'The Three Musketeers'; -set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without some city - or other registering in its archives an event of this kind. There were +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without + some city or other registering in its archives an event of this kind. There were nobles, who made war against each other; there was the king, who made war against the cardinal; there was Spain, which made war against the king. Then, in addition to these concealed or public, secret or open @@ -152,8 +154,8 @@ set Words['3musk002']['para'] = 'In those times panics were common, and few days cause of the hubbub was apparent to all'; set Words['3musk003']['book'] = 'The Three Musketeers'; -set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however large - the sum may be; but you ought also to endeavor to perfect yourself in +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however + large the sum may be; but you ought also to endeavor to perfect yourself in the exercises becoming a gentleman. I will write a letter today to the Director of the Royal Academy, and tomorrow he will admit you without any expense to yourself. Do not refuse this little service. Our @@ -165,8 +167,8 @@ set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means yo set Words['thelostworld001']['book'] = 'The Lost World'; -set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined against the - red curtain. How beautiful she was! And yet how aloof! We had been +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined + against the red curtain. How beautiful she was! And yet how aloof! We had been friends, quite good friends; but never could I get beyond the same comradeship which I might have established with one of my fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, @@ -180,8 +182,8 @@ set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profil as that--or had inherited it in that race memory which we call instinct.'; set Words['thelostworld002']['book'] = 'The Lost World'; -set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, red-headed news - editor, and I rather hoped that he liked me. Of course, Beaumont was +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, + red-headed news editor, and I rather hoped that he liked me. 
Of course, Beaumont was the real boss; but he lived in the rarefied atmosphere of some Olympian height from which he could distinguish nothing smaller than an international crisis or a split in the Cabinet. Sometimes we saw him -- cgit v1.2.3 From 1d9f0df0652f455145d2dfed43a9407df6de6c43 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 13 Jun 2013 14:46:25 -0700 Subject: Fix some comments and style --- core/src/main/java/spark/network/netty/FileClient.java | 2 +- core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 8 ++++---- core/src/main/scala/spark/storage/BlockFetcherIterator.scala | 6 +----- core/src/main/scala/spark/storage/DiskStore.scala | 3 +-- core/src/test/scala/spark/ShuffleSuite.scala | 3 +-- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 517772202f..a4bb4bc701 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -30,7 +30,7 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) // Disable connect timeout + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) .handler(new FileClientChannelInitializer(handler)); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index afb2cdbb3a..8d5194a737 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -18,8 +18,9 @@ private[spark] class ShuffleCopier extends Logging { resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val fc = new FileClient(handler, - System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt) + val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt + val fc = new FileClient(handler, connectTimeout) + try { fc.init() fc.connect(host, port) @@ -29,8 +30,7 @@ private[spark] class ShuffleCopier extends Logging { } catch { // Handle any socket-related exceptions in FileClient case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + - " failed", e) + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) handler.handleError(blockId) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index bb78207c9f..bec876213e 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -322,11 +322,7 @@ object BlockFetcherIterator { override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() - // if all the results has been retrieved, shutdown the copiers - // NO need to stop the copiers if we got all the blocks ? 
- // if (resultsGotten == _numBlocksToFetch && copiers != null) { - // stopCopiers() - // } + // If all the results has been retrieved, copiers will exit automatically (result.blockId, if (result.failed) None else Some(result.deserialize())) } } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 0af6e4a359..15ab840155 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -212,10 +212,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task ? + // was rescheduled on the same machine as the old task. logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") file.delete() - // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 33b02fff80..1916885a73 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -376,8 +376,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) - // NOTE: The default Java serializer doesn't create zero-sized blocks. - // So, use Kryo + // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD(b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId -- cgit v1.2.3 From 44b8dbaedeb88f12ea911968c524883805f7ad95 Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Thu, 13 Jun 2013 16:23:15 -0700 Subject: use Iterator.single(elem) instead of Iterator(elem) for improved performance based on scaladocs --- core/src/main/scala/spark/RDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index ec5e5e2433..bc9c17d507 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -734,7 +734,7 @@ abstract class RDD[T: ClassManifest]( val topK = mapPartitions { items => val queue = new BoundedPriorityQueue[T](num) queue ++= items - Iterator(queue) + Iterator.single(queue) }.reduce { (queue1, queue2) => queue1 ++= queue2 queue1 -- cgit v1.2.3 From 93b3f5e535c509a017a433b72249fc49c79d4a0f Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Thu, 13 Jun 2013 16:26:35 -0700 Subject: drop unneeded ClassManifest implicit --- core/src/main/scala/spark/util/BoundedPriorityQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala index 53ee95a02e..ef01beaea5 100644 --- a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -8,7 +8,7 @@ import scala.collection.generic.Growable * add/offer methods such that only the top K elements are retained. The top * K elements are defined by an implicit Ordering[A]. 
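 * A quick usage sketch (illustrative only, not from the patch itself; relies on the
 * implicit Ordering[Int]):
 *   val top3 = new BoundedPriorityQueue[Int](3)
 *   top3 ++= Seq(5, 1, 9, 7, 2)   // only the three largest elements (5, 7, 9) are retained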
*/ -class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A]) +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) extends JPriorityQueue[A](maxSize, ord) with Growable[A] { override def offer(a: A): Boolean = { -- cgit v1.2.3 From 6738178d0daf1bbe7441db7c0c773a29bb2ec388 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Jun 2013 23:59:42 -0700 Subject: SPARK-772: groupByKey should disable map side combine. --- core/src/main/scala/spark/PairRDDFunctions.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 15593db0d9..fa4bbfc76f 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -19,7 +19,7 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil} import spark.partial.BoundedDouble import spark.partial.PartialResult @@ -187,11 +187,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * partitioning of the resulting key-value pair RDD by passing a Partitioner. */ def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + // groupByKey shouldn't use map side combine because map side combine does not + // reduce the amount of data shuffled and requires all map side data be inserted + // into a hash table, leading to more objects in the old gen. def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner) + createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } -- cgit v1.2.3 From 2cc188fd546fa061812f9fd4f72cf936bd01a0e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Jun 2013 00:10:54 -0700 Subject: SPARK-774: cogroup should also disable map side combine by default --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 7599ba1a02..8966f9f86e 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} +import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -49,12 +49,16 @@ private[spark] class CoGroupAggregator * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output. - * @param mapSideCombine flag indicating whether to merge values before shuffle step. 
+ * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag + * is on, Spark does an extra pass over the data on the map side to merge + * all values belonging to the same key together. This can reduce the amount + * of data shuffled if and only if the number of distinct keys is very small, + * and the ratio of key size to value size is also very small. */ class CoGroupedRDD[K]( @transient var rdds: Seq[RDD[(K, _)]], part: Partitioner, - val mapSideCombine: Boolean = true, + val mapSideCombine: Boolean = false, val serializerClass: String = null) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { -- cgit v1.2.3 From 53add598f2fe09759a0df1e08f87f70503f808c5 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Sat, 15 Jun 2013 01:34:17 +0800 Subject: Update LocalSchedulerSuite to avoid using sleep for task launch --- .../spark/scheduler/LocalSchedulerSuite.scala | 83 +++++++++++++++------- 1 file changed, 59 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala index 37d14ed113..8bd813fd14 100644 --- a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala @@ -9,9 +9,7 @@ import spark.scheduler.cluster._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.{ConcurrentMap, HashMap} import java.util.concurrent.Semaphore -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger - +import java.util.concurrent.CountDownLatch import java.util.Properties class Lock() { @@ -35,9 +33,19 @@ class Lock() { object TaskThreadInfo { val threadToLock = HashMap[Int, Lock]() val threadToRunning = HashMap[Int, Boolean]() + val threadToStarted = HashMap[Int, CountDownLatch]() } - +/* + * 1. each thread contains one job. + * 2. each job contains one stage. + * 3. each stage only contains one task. + * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure + * it will get cpu core resource, and will wait to finished after user manually + * release "Lock" and then cluster will contain another free cpu cores. + * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, + * thus it will be scheduled later when cluster has free cpu cores. 
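 * A minimal sketch of the latch-based hand-off relied on below (the names `started` and
 * `lock` are illustrative, not the suite's exact code):
 *   val started = new CountDownLatch(1)
 *   new Thread { override def run() { started.countDown(); lock.jobWait() } }.start()
 *   started.await()   // returns only once the task body has actually begun running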
+ */ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { @@ -45,22 +53,23 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { TaskThreadInfo.threadToRunning(threadIndex) = false val nums = sc.parallelize(threadIndex to threadIndex, 1) TaskThreadInfo.threadToLock(threadIndex) = new Lock() + TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) new Thread { - if (poolName != null) { - sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToLock(number).jobWait() - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - TaskThreadInfo.threadToRunning(threadIndex) = false - } + if (poolName != null) { + sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToStarted(number).countDown() + TaskThreadInfo.threadToLock(number).jobWait() + TaskThreadInfo.threadToRunning(number) = false + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + } }.start() - Thread.sleep(2000) } test("Local FIFO scheduler end-to-end test") { @@ -69,11 +78,24 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { val sem = new Semaphore(0) createThread(1,null,sc,sem) + TaskThreadInfo.threadToStarted(1).await() createThread(2,null,sc,sem) + TaskThreadInfo.threadToStarted(2).await() createThread(3,null,sc,sem) + TaskThreadInfo.threadToStarted(3).await() createThread(4,null,sc,sem) + TaskThreadInfo.threadToStarted(4).await() + // thread 5 and 6 (stage pending)must meet following two points + // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager + // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() + // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 + // So I just use "sleep" 1s here for each thread. + // TODO: any better solution? 
createThread(5,null,sc,sem) + Thread.sleep(1000) createThread(6,null,sc,sem) + Thread.sleep(1000) + assert(TaskThreadInfo.threadToRunning(1) === true) assert(TaskThreadInfo.threadToRunning(2) === true) assert(TaskThreadInfo.threadToRunning(3) === true) @@ -82,8 +104,8 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { assert(TaskThreadInfo.threadToRunning(6) === false) TaskThreadInfo.threadToLock(1).jobFinished() - Thread.sleep(1000) - + TaskThreadInfo.threadToStarted(5).await() + assert(TaskThreadInfo.threadToRunning(1) === false) assert(TaskThreadInfo.threadToRunning(2) === true) assert(TaskThreadInfo.threadToRunning(3) === true) @@ -92,7 +114,7 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { assert(TaskThreadInfo.threadToRunning(6) === false) TaskThreadInfo.threadToLock(3).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(6).await() assert(TaskThreadInfo.threadToRunning(1) === false) assert(TaskThreadInfo.threadToRunning(2) === true) @@ -116,23 +138,31 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { System.setProperty("spark.fairscheduler.allocation.file", xmlPath) createThread(10,"1",sc,sem) + TaskThreadInfo.threadToStarted(10).await() createThread(20,"2",sc,sem) + TaskThreadInfo.threadToStarted(20).await() createThread(30,"3",sc,sem) + TaskThreadInfo.threadToStarted(30).await() assert(TaskThreadInfo.threadToRunning(10) === true) assert(TaskThreadInfo.threadToRunning(20) === true) assert(TaskThreadInfo.threadToRunning(30) === true) createThread(11,"1",sc,sem) + TaskThreadInfo.threadToStarted(11).await() createThread(21,"2",sc,sem) + TaskThreadInfo.threadToStarted(21).await() createThread(31,"3",sc,sem) + TaskThreadInfo.threadToStarted(31).await() assert(TaskThreadInfo.threadToRunning(11) === true) assert(TaskThreadInfo.threadToRunning(21) === true) assert(TaskThreadInfo.threadToRunning(31) === true) createThread(12,"1",sc,sem) + TaskThreadInfo.threadToStarted(12).await() createThread(22,"2",sc,sem) + TaskThreadInfo.threadToStarted(22).await() createThread(32,"3",sc,sem) assert(TaskThreadInfo.threadToRunning(12) === true) @@ -140,20 +170,25 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { assert(TaskThreadInfo.threadToRunning(32) === false) TaskThreadInfo.threadToLock(10).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(32).await() + assert(TaskThreadInfo.threadToRunning(32) === true) + //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager + // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. + //2. priority of 23 and 33 will be meaningless as using fair scheduler here. 
createThread(23,"2",sc,sem) createThread(33,"3",sc,sem) + Thread.sleep(1000) TaskThreadInfo.threadToLock(11).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(23).await() assert(TaskThreadInfo.threadToRunning(23) === true) assert(TaskThreadInfo.threadToRunning(33) === false) TaskThreadInfo.threadToLock(12).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(33).await() assert(TaskThreadInfo.threadToRunning(33) === true) -- cgit v1.2.3 From e8801d44900153eae6412963d2f3e2f19bfdc4e9 Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Fri, 14 Jun 2013 23:39:05 -0700 Subject: use delegation for BoundedPriorityQueue, add Java API --- core/src/main/scala/spark/RDD.scala | 9 ++---- core/src/main/scala/spark/api/java/JavaRDD.scala | 1 - .../main/scala/spark/api/java/JavaRDDLike.scala | 27 ++++++++++++++++- .../scala/spark/util/BoundedPriorityQueue.scala | 35 ++++++++++------------ 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index bc9c17d507..4a4616c843 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -731,19 +731,14 @@ abstract class RDD[T: ClassManifest]( * @return an array of top elements */ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { - val topK = mapPartitions { items => + mapPartitions { items => val queue = new BoundedPriorityQueue[T](num) queue ++= items Iterator.single(queue) }.reduce { (queue1, queue2) => queue1 ++= queue2 queue1 - } - - val builder = Array.newBuilder[T] - builder.sizeHint(topK.size) - builder ++= topK - builder.result() + }.toArray } /** diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index eb81ed64cd..626b499454 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] { */ def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - } object JavaRDD { diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 9b74d1226f..3e9c779d7b 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -1,6 +1,6 @@ package spark.api.java -import java.util.{List => JList} +import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ @@ -351,4 +351,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def toDebugString(): String = { rdd.toDebugString } + + /** + * Returns the top K elements from this RDD as defined by + * the specified Comparator[T]. + * @param num the number of top elements to return + * @param comp the comparator that defines the order + * @return an array of top elements + */ + def top(num: Int, comp: Comparator[T]): JList[T] = { + import scala.collection.JavaConversions._ + val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) + val arr: java.util.Collection[T] = topElems.toSeq + new java.util.ArrayList(arr) + } + + /** + * Returns the top K elements from this RDD using the + * natural ordering for T. 
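   * A usage sketch (illustrative): for an RDD containing 10, 4, 2, 12 and 3,
   * top(1) returns [12] and top(2) returns the two largest values, 12 and 10.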
+ * @param num the number of top elements to return + * @return an array of top elements + */ + def top(num: Int): JList[T] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] + top(num, comp) + } } diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala index ef01beaea5..4bc5db8bb7 100644 --- a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -1,30 +1,30 @@ package spark.util +import java.io.Serializable import java.util.{PriorityQueue => JPriorityQueue} import scala.collection.generic.Growable +import scala.collection.JavaConverters._ /** - * Bounded priority queue. This class modifies the original PriorityQueue's - * add/offer methods such that only the top K elements are retained. The top - * K elements are defined by an implicit Ordering[A]. + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. */ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) - extends JPriorityQueue[A](maxSize, ord) with Growable[A] { + extends Iterable[A] with Growable[A] with Serializable { - override def offer(a: A): Boolean = { - if (size < maxSize) super.offer(a) - else maybeReplaceLowest(a) - } + private val underlying = new JPriorityQueue[A](maxSize, ord) - override def add(a: A): Boolean = offer(a) + override def iterator: Iterator[A] = underlying.iterator.asScala override def ++=(xs: TraversableOnce[A]): this.type = { - xs.foreach(add) + xs.foreach { this += _ } this } override def +=(elem: A): this.type = { - add(elem) + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) this } @@ -32,17 +32,14 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) this += elem1 += elem2 ++= elems } + override def clear() { underlying.clear() } + private def maybeReplaceLowest(a: A): Boolean = { - val head = peek() + val head = underlying.peek() if (head != null && ord.gt(a, head)) { - poll() - super.offer(a) + underlying.poll() + underlying.offer(a) } else false } } -object BoundedPriorityQueue { - import scala.collection.JavaConverters._ - implicit def asIterable[A](queue: BoundedPriorityQueue[A]): Iterable[A] = queue.asScala -} - -- cgit v1.2.3 From 479442a9b913b08a64da4bd5848111d950105336 Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Sat, 15 Jun 2013 17:35:55 -0700 Subject: Add zeroLengthPartitions() test to make sure, e.g., StatCounter.scala can handle empty partitions without incorrectly returning NaN --- core/src/test/scala/spark/JavaAPISuite.java | 22 ++++++++++++++++++++++ project/plugins.sbt | 2 ++ 2 files changed, 24 insertions(+) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 93bb69b41c..3190a43e73 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -314,6 +314,28 @@ public class JavaAPISuite implements Serializable { List take = rdd.take(5); } + @Test + public void zeroLengthPartitions() { + // Create RDD with some consecutive empty partitions (including the "first" one) + JavaDoubleRDD rdd = sc + .parallelizeDoubles(Arrays.asList(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(new Function() { + @Override + public Boolean call(Double x) { + return x > 0.0; + } + }); 
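    // Why this matters (illustrative note): averaging an empty partition naively gives
    // 0/0 = NaN; the assertions below hold only if StatCounter merges partial results
    // weighted by element count, so empty partitions contribute nothing.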
+ + // Run the partitions, including the consecutive empty ones, through StatCounter + StatCounter stats = rdd.stats(); + Assert.assertEquals(6.0, stats.sum(), 0.01); + Assert.assertEquals(6.0/2, rdd.mean(), 0.01); + Assert.assertEquals(1.0, rdd.variance(), 0.01); + Assert.assertEquals(1.0, rdd.stdev(), 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } + @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/project/plugins.sbt b/project/plugins.sbt index d4f2442872..25b812a28d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") + +libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M4" % "test" -- cgit v1.2.3 From 5c886194e458c64fcf24066af351bde47dd8bf12 Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Sun, 16 Jun 2013 01:23:48 -0700 Subject: Move zero-length partition testing from JavaAPISuite.java to PartitioningSuite.scala --- core/src/test/scala/spark/JavaAPISuite.java | 22 ---------------------- core/src/test/scala/spark/PartitioningSuite.scala | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 3190a43e73..93bb69b41c 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -314,28 +314,6 @@ public class JavaAPISuite implements Serializable { List take = rdd.take(5); } - @Test - public void zeroLengthPartitions() { - // Create RDD with some consecutive empty partitions (including the "first" one) - JavaDoubleRDD rdd = sc - .parallelizeDoubles(Arrays.asList(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) - .filter(new Function() { - @Override - public Boolean call(Double x) { - return x > 0.0; - } - }); - - // Run the partitions, including the consecutive empty ones, through StatCounter - StatCounter stats = rdd.stats(); - Assert.assertEquals(6.0, stats.sum(), 0.01); - Assert.assertEquals(6.0/2, rdd.mean(), 0.01); - Assert.assertEquals(1.0, rdd.variance(), 0.01); - Assert.assertEquals(1.0, rdd.stdev(), 0.01); - - // Add other tests here for classes that should be able to handle empty partitions correctly - } - @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 60db759c25..e5745c81b3 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,10 +1,10 @@ package spark import org.scalatest.FunSuite - import scala.collection.mutable.ArrayBuffer - import SparkContext._ +import spark.util.StatCounter +import scala.math._ class PartitioningSuite extends FunSuite with LocalSparkContext { @@ -120,4 +120,21 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) } + + test("Zero-length partitions should be correctly handled") { + // Create RDD with some consecutive empty 
partitions (including the "first" one) + sc = new SparkContext("local", "test") + val rdd: RDD[Double] = sc + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(_ >= 0.0) + + // Run the partitions, including the consecutive empty ones, through StatCounter + val stats: StatCounter = rdd.stats(); + assert(abs(6.0 - stats.sum) < 0.01); + assert(abs(6.0/2 - rdd.mean) < 0.01); + assert(abs(1.0 - rdd.variance) < 0.01); + assert(abs(1.0 - rdd.stdev) < 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } } -- cgit v1.2.3 From f91195cc150a3ead122046d14bd35b4fcf28c9cb Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Sun, 16 Jun 2013 01:29:53 -0700 Subject: Import just scala.math.abs rather than scala.math._ --- core/src/test/scala/spark/PartitioningSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index e5745c81b3..16f93e71a3 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -4,7 +4,7 @@ import org.scalatest.FunSuite import scala.collection.mutable.ArrayBuffer import SparkContext._ import spark.util.StatCounter -import scala.math._ +import scala.math.abs class PartitioningSuite extends FunSuite with LocalSparkContext { -- cgit v1.2.3 From fb6d733fa88aa124deecf155af40cc095ecca5b3 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Sun, 16 Jun 2013 22:32:55 +0000 Subject: update according to comments --- core/src/main/scala/spark/RDD.scala | 71 ++------------------------- core/src/main/scala/spark/rdd/PipedRDD.scala | 29 +++++------ core/src/test/scala/spark/PipedRDDSuite.scala | 13 +++-- 3 files changed, 24 insertions(+), 89 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index a1c9604324..152f7be9bb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -355,68 +355,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** - * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. - */ - def pipe[U<: Seq[String]]( - command: String, - env: Map[String, String], - transform: (T,String => Unit) => Any, - pipeContext: Broadcast[U], - delimiter: String): RDD[String] = - new PipedRDD(this, command, env, transform, pipeContext, delimiter) - - /** - * Return an RDD created by piping elements to a forked external process. 
- * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. - */ - def pipe[U<: Seq[String]]( - command: String, - transform: (T,String => Unit) => Any, - pipeContext: Broadcast[U]): RDD[String] = - new PipedRDD(this, command, Map[String, String](), transform, pipeContext, "\u0001") - - /** - * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. - */ - def pipe[U<: Seq[String]]( - command: String, - env: Map[String, String], - transform: (T,String => Unit) => Any, - pipeContext: Broadcast[U]): RDD[String] = - new PipedRDD(this, command, env, transform, pipeContext, "\u0001") /** * Return an RDD created by piping elements to a forked external process. @@ -432,13 +370,12 @@ abstract class RDD[T: ClassManifest]( * external process with "^A" as the delimiter in the end of context data. Delimiter can also * be customized by the last parameter delimiter. */ - def pipe[U<: Seq[String]]( + def pipe( command: Seq[String], env: Map[String, String] = Map(), - transform: (T,String => Unit) => Any = null, - pipeContext: Broadcast[U] = null, - delimiter: String = "\u0001"): RDD[String] = - new PipedRDD(this, command, env, transform, pipeContext, delimiter) + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, printRDDElement) /** * Return a new RDD by applying a function to each partition of this RDD. 
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d58aaae709..b2c07891ab 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -16,14 +16,12 @@ import spark.broadcast.Broadcast * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest, U <: Seq[String]]( +class PipedRDD[T: ClassManifest]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], - transform: (T, String => Unit) => Any, - pipeContext: Broadcast[U], - delimiter: String - ) + printPipeContext: (String => Unit) => Unit, + printRDDElement: (T, String => Unit) => Unit) extends RDD[String](prev) { // Similar to Runtime.exec(), if we are given a single string, split it into words @@ -32,10 +30,9 @@ class PipedRDD[T: ClassManifest, U <: Seq[String]]( prev: RDD[T], command: String, envVars: Map[String, String] = Map(), - transform: (T, String => Unit) => Any = null, - pipeContext: Broadcast[U] = null, - delimiter: String = "\u0001") = - this(prev, PipedRDD.tokenize(command), envVars, transform, pipeContext, delimiter) + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null) = + this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -64,17 +61,13 @@ class PipedRDD[T: ClassManifest, U <: Seq[String]]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - // input the pipeContext firstly - if ( pipeContext != null) { - for (elem <- pipeContext.value) { - out.println(elem) - } - // delimiter\n as the marker of the end of the pipeContext - out.println(delimiter) + // input the pipe context firstly + if ( printPipeContext != null) { + printPipeContext(out.println(_)) } for (elem <- firstParent[T].iterator(split, context)) { - if (transform != null) { - transform(elem, out.println(_)) + if (printRDDElement != null) { + printRDDElement(elem, out.println(_)) } else { out.println(elem) } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index d2852867de..ed075f93ec 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -22,9 +22,12 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { test("advanced pipe") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) - val piped = nums.pipe(Seq("cat"), Map[String, String](), - (i:Int, f: String=> Unit) => f(i + "_"), sc.broadcast(List("0"))) + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Int, f: String=> Unit) => f(i + "_")) val c = piped.collect() @@ -40,8 +43,10 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) val d = nums1.groupBy(str=>str.split("\t")(0)). 
- pipe(Seq("cat"), Map[String, String](), (i:Tuple2[String, Seq[String]], f: String=> Unit) => - {for (e <- i._2){ f(e + "_")}}, sc.broadcast(List("0"))).collect() + pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() assert(d.size === 8) assert(d(0) === "0") assert(d(1) === "\u0001") -- cgit v1.2.3 From 4508089fc342802a2f37fea6893cd47abd81fdd7 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Mon, 17 Jun 2013 05:23:46 +0000 Subject: refine comments and add sc.clean --- core/src/main/scala/spark/RDD.scala | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 05ff399a7b..223dcdc19d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -361,24 +361,30 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. + * The print behavior can be customized by providing two functions. + * + * @param command command to run in forked process. + * @param env environment variables to set. + * @param printPipeContext Before piping elements, this function is called as an oppotunity + * to pipe context data. Print line function (like out.println) will be + * passed as printPipeContext's parameter. + * @param printPipeContext Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. + * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} + * @return the result RDD */ def pipe( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = - new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, printRDDElement) + new PipedRDD(this, command, env, + if (printPipeContext ne null) sc.clean(printPipeContext) else null, + if (printRDDElement ne null) sc.clean(printRDDElement) else null) /** * Return a new RDD by applying a function to each partition of this RDD. 
-- cgit v1.2.3 From 1450296797e53f1a01166c885050091df9c96e2e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 17 Jun 2013 16:58:23 -0400 Subject: SPARK-781: Log the temp directory path when Spark says "Failed to create temp directory". --- core/src/main/scala/spark/Utils.scala | 4 +-- core/src/main/scala/spark/storage/DiskStore.scala | 34 +++++++++++------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ec15326014..fd7b8cc8d5 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -116,8 +116,8 @@ private object Utils extends Logging { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory after " + maxAttempts + - " attempts!") + throw new IOException("Failed to create a temp directory under (" + root + ") after " + + maxAttempts + " attempts!") } try { dir = new File(root, "spark-" + UUID.randomUUID.toString) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c7281200e7..9914beec99 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -82,15 +82,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def size(): Long = lastValidPosition } - val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - var shuffleSender : ShuffleSender = null + private var shuffleSender : ShuffleSender = null // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. 
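  // Roughly, the hashing idea (an illustrative sketch; the exact code lives in getFile):
  //   val hash = math.abs(blockId.hashCode)
  //   val dirId = hash % localDirs.length
  //   val subDirId = (hash / localDirs.length) % subDirsPerLocalDir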
- val localDirs = createLocalDirs() - val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private val localDirs: Array[File] = createLocalDirs() + private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) addShutdownHook() @@ -99,7 +99,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) new DiskBlockObjectWriter(blockId, serializer, bufferSize) } - override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -232,8 +231,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map(rootDir => { - var foundLocalDir: Boolean = false + rootDirs.split(",").map { rootDir => + var foundLocalDir = false var localDir: File = null var localDirId: String = null var tries = 0 @@ -248,7 +247,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) } } if (!foundLocalDir) { @@ -258,7 +257,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } logInfo("Created local directory at " + localDir) localDir - }) + } } private def addShutdownHook() { @@ -266,15 +265,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") - try { - localDirs.foreach { localDir => + localDirs.foreach { localDir => + try { if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) } - if (shuffleSender != null) { - shuffleSender.stop - } - } catch { - case t: Throwable => logError("Exception while deleting local spark dirs", t) + } + if (shuffleSender != null) { + shuffleSender.stop } } }) -- cgit v1.2.3 From be3c406edf06d5ab9da98097c28ce3eebc958b8e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 17 Jun 2013 17:07:51 -0400 Subject: Fixed the typo pointed out by Matei. 
--- core/src/main/scala/spark/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index fd7b8cc8d5..645c18541e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -116,7 +116,7 @@ private object Utils extends Logging { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory under (" + root + ") after " + + throw new IOException("Failed to create a temp directory (under " + root + ") after " + maxAttempts + " attempts!") } try { -- cgit v1.2.3 From 2ab311f4cee3f918dc28daaebd287b11c9f63429 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 18 Jun 2013 00:40:25 +0200 Subject: Removed second version of junit test plugin from plugins.sbt --- project/plugins.sbt | 2 -- 1 file changed, 2 deletions(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 25b812a28d..d4f2442872 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,5 +16,3 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") - -libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M4" % "test" -- cgit v1.2.3 From 1e9269c3eeeaa3a481b95521c703032ed84abd68 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 13 Jun 2013 10:46:22 +0800 Subject: reduce ZippedPartitionsRDD's getPreferredLocations complexity --- core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index dd9f3c2680..b234428ab2 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -53,14 +53,10 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y)) // Remove exact match and then do host local match. 
- val otherNodePreferredLocations = rddSplitZip.map(x => { - x._1.preferredLocations(x._2).map(hostPort => { - val host = Utils.parseHostPort(hostPort)._1 - - if (exactMatchLocations.contains(host)) null else host - }).filter(_ != null) - }) - val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y)) + val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1) + val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1)) + .reduce((x, y) => x.intersect(y)) + val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) } otherNodeLocalLocations ++ exactMatchLocations } -- cgit v1.2.3 From 0a2a9bce1e83e891334985c29176c6426b8b1751 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Tue, 18 Jun 2013 21:30:13 +0000 Subject: fix typo and coding style --- core/src/main/scala/spark/RDD.scala | 14 +++++++------- core/src/main/scala/spark/rdd/PipedRDD.scala | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 223dcdc19d..709271d4eb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -368,13 +368,13 @@ abstract class RDD[T: ClassManifest]( * @param printPipeContext Before piping elements, this function is called as an oppotunity * to pipe context data. Print line function (like out.println) will be * passed as printPipeContext's parameter. - * @param printPipeContext Use this function to customize how to pipe elements. This function - * will be called with each RDD element as the 1st parameter, and the - * print line function (like out.println()) as the 2nd parameter. - * An example of pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} + * @param printRDDElement Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. 
+ * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} * @return the result RDD */ def pipe( diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index b2c07891ab..c0baf43d43 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -62,7 +62,7 @@ class PipedRDD[T: ClassManifest]( val out = new PrintWriter(proc.getOutputStream) // input the pipe context firstly - if ( printPipeContext != null) { + if (printPipeContext != null) { printPipeContext(out.println(_)) } for (elem <- firstParent[T].iterator(split, context)) { -- cgit v1.2.3 From 7902baddc797f86f5bdbcc966f5cd60545638bf7 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 19 Jun 2013 13:34:30 +0200 Subject: Update ASM to version 4.0 --- core/pom.xml | 4 ++-- core/src/main/scala/spark/ClosureCleaner.scala | 11 +++++------ pom.xml | 6 +++--- project/SparkBuild.scala | 2 +- repl/src/main/scala/spark/repl/ExecutorClassLoader.scala | 3 +-- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index d8687bf991..88f0ed70f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -32,8 +32,8 @@ compress-lzf - asm - asm-all + org.ow2.asm + asm com.google.protobuf diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 50d6a1c5c9..d5e7132ff9 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -5,8 +5,7 @@ import java.lang.reflect.Field import scala.collection.mutable.Map import scala.collection.mutable.Set -import org.objectweb.asm.{ClassReader, MethodVisitor, Type} -import org.objectweb.asm.commons.EmptyVisitor +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import org.objectweb.asm.Opcodes._ import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} @@ -162,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging { } } -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { @@ -188,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten } } -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null override def visit(version: Int, access: Int, name: String, sig: String, @@ -198,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { val 
argTypes = Type.getArgumentTypes(desc) diff --git a/pom.xml b/pom.xml index c893ec755e..3bcb2a3f34 100644 --- a/pom.xml +++ b/pom.xml @@ -190,9 +190,9 @@ 0.8.4 - asm - asm-all - 3.3.1 + org.ow2.asm + asm + 4.0 com.google.protobuf diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 824af821f9..b1f3f9a2ea 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -148,7 +148,7 @@ object SparkBuild extends Build { "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "commons-daemon" % "commons-daemon" % "1.0.10", "com.ning" % "compress-lzf" % "0.8.4", - "asm" % "asm-all" % "3.3.1", + "org.ow2.asm" % "asm" % "4.0", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", "com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty), diff --git a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala index 13d81ec1cf..0e9aa863b5 100644 --- a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala @@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.objectweb.asm._ -import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ @@ -83,7 +82,7 @@ extends ClassLoader(parent) { } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassAdapter(cv) { +extends ClassVisitor(ASM4, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) -- cgit v1.2.3 From ae7a5da6b31f5bf64f713b3d9bff6e441d8615b4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 20 Jun 2013 18:44:46 +0200 Subject: Fix some dependency issues in SBT build (same will be needed for Maven): - Exclude a version of ASM 3.x that comes from HBase - Don't use a special ASF repo for HBase - Update SLF4J version - Add sbt-dependency-graph plugin so we can easily find dependency trees --- project/SparkBuild.scala | 10 +++++----- project/plugins.sbt | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b1f3f9a2ea..24c8b734d0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -125,12 +125,13 @@ object SparkBuild extends Build { publishMavenStyle in MavenCompile := true, publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn - ) + ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings - val slf4jVersion = "1.6.1" + val slf4jVersion = "1.7.2" val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson") val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeAsm = ExclusionRule(organization = "asm") def coreSettings = sharedSettings ++ Seq( name := "spark-core", @@ -201,11 +202,10 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"), libraryDependencies ++= Seq( "com.twitter" % "algebird-core_2.9.2" % "0.1.11", - "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty), + "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm), "org.apache.cassandra" % "cassandra-all" % "1.2.5" 
exclude("com.google.guava", "guava") @@ -224,7 +224,7 @@ object SparkBuild extends Build { name := "spark-streaming", libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty), - "com.github.sgroschupf" % "zkclient" % "0.1", + "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty), "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), "com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty) ) diff --git a/project/plugins.sbt b/project/plugins.sbt index d4f2442872..f806e66481 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") + +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3") -- cgit v1.2.3 From 52407951541399e60a5292394b3a443a5e7ff281 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 17:38:23 +0800 Subject: edit according to comments --- core/src/main/scala/spark/RDD.scala | 6 +- core/src/main/scala/spark/Utils.scala | 10 +-- .../main/scala/spark/executor/TaskMetrics.scala | 2 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 14 +++-- .../src/main/scala/spark/scheduler/JobLogger.scala | 72 ++++++++++------------ .../main/scala/spark/scheduler/SparkListener.scala | 25 ++++---- .../scala/spark/scheduler/JobLoggerSuite.scala | 2 +- 7 files changed, 62 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 8c0b7ca417..b17398953b 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -114,10 +114,10 @@ abstract class RDD[T: ClassManifest]( this } - /**User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo._4 + /** User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo.firstUserClass - /**reset generator*/ + /** Reset generator*/ def setGenerator(_generator: String) = { generator = _generator } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1630b2b4b0..1cfaee79b1 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -522,13 +522,14 @@ private object Utils extends Logging { execute(command, new File(".")) } - + class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. 
*/ - def getCallSiteInfo = { + def getCallSiteInfo: CallSiteInfo = { val trace = Thread.currentThread.getStackTrace().filter( el => (!el.getMethodName.contains("getStackTrace"))) @@ -560,12 +561,13 @@ private object Utils extends Logging { } } } - (lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) + new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } def formatSparkCallSite = { val callSiteInfo = getCallSiteInfo - "%s at %s:%s".format(callSiteInfo._1, callSiteInfo._2, callSiteInfo._3) + "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, + callSiteInfo.firstUserLine) } /** * Try to find a free port to bind to on the local host. This should ideally never be needed, diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index 26e8029365..1dc13754f9 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -2,7 +2,7 @@ package spark.executor class TaskMetrics extends Serializable { /** - * host's name the task runs on + * Host's name the task runs on */ var hostname: String = _ diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index e281e5a8db..4336f2f36d 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -312,7 +312,8 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => - sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion))) + sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task, + completion.reason, completion.taskInfo, completion.taskMetrics))) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -323,8 +324,8 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - sparkListeners.foreach(_.onJobEnd(SparkListenerJobCancelled(job, - "SPARKCONTEXT_SHUTDOWN"))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, + JobFailed(error)))) } return true } @@ -527,7 +528,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - sparkListeners.foreach(_.onJobEnd(SparkListenerJobSuccess(job))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded))) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -668,10 +669,11 @@ class DAGScheduler( val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - job.listener.jobFailed(new SparkException("Job failed: " + reason)) + val error = new SparkException("Job failed: " + reason) + job.listener.jobFailed(error) activeJobs -= job resultStageToJob -= resultStage - sparkListeners.foreach(_.onJobEnd(SparkListenerJobFailed(job, failedStage))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index 002c5826cb..760a0252b7 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -12,7 +12,7 
@@ import spark._ import spark.executor.TaskMetrics import spark.scheduler.cluster.TaskInfo -// used to record runtime information for each job, including RDD graph +// Used to record runtime information for each job, including RDD graph // tasks' start/stop shuffle information and information from outside class JobLogger(val logDirName: String) extends SparkListener with Logging { @@ -49,21 +49,17 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { processStageSubmittedEvent(stage, taskSize) case StageCompleted(stageInfo) => processStageCompletedEvent(stageInfo) - case SparkListenerJobSuccess(job) => - processJobEndEvent(job) - case SparkListenerJobFailed(job, failedStage) => - processJobEndEvent(job, failedStage) - case SparkListenerJobCancelled(job, reason) => - processJobEndEvent(job, reason) - case SparkListenerTaskEnd(event) => - processTaskEndEvent(event) + case SparkListenerJobEnd(job, result) => + processJobEndEvent(job, result) + case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) => + processTaskEndEvent(task, reason, taskInfo, taskMetrics) case _ => } } } }.start() - //create a folder for log files, the folder's name is the creation time of the jobLogger + // Create a folder for log files, the folder's name is the creation time of the jobLogger protected def createLogDir() { val dir = new File(logDir + "/" + logDirName + "/") if (dir.exists()) { @@ -244,54 +240,50 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { eventQueue.put(taskEnd) } - protected def processTaskEndEvent(event: CompletionEvent) { + protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason, + taskInfo: TaskInfo, taskMetrics: TaskMetrics) { var taskStatus = "" - event.task match { + task match { case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" } - event.reason match { + reason match { case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(event.task.stageId, taskStatus, event.taskInfo, event.taskMetrics) + recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics) case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + event.taskInfo.taskId + - " STAGE_ID=" + event.task.stageId - stageLogInfo(event.task.stageId, taskStatus) + taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + stageLogInfo(task.stageId, taskStatus) case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - taskStatus += " STATUS=FETCHFAILED TID=" + event.taskInfo.taskId + " STAGE_ID=" + - event.task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + + taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + + task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId - stageLogInfo(event.task.stageId, taskStatus) + stageLogInfo(task.stageId, taskStatus) case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + event.taskInfo.taskId + - " STAGE_ID=" + event.task.stageId + " INFO=" + message - stageLogInfo(event.task.stageId, taskStatus) + taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + " INFO=" + message + stageLogInfo(task.stageId, taskStatus) case _ => } } - override def onJobEnd(jobEnd: SparkListenerEvents) { + override def onJobEnd(jobEnd: SparkListenerJobEnd) { eventQueue.put(jobEnd) } - protected def processJobEndEvent(job: ActiveJob) { - val info = "JOB_ID=" + job.runId + " 
STATUS=SUCCESS" - jobLogInfo(job.runId, info) - closeLogWriter(job.runId) - } - - protected def processJobEndEvent(job: ActiveJob, failedStage: Stage) { - val info = "JOB_ID=" + job.runId + " STATUS=FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" - + failedStage.id - jobLogInfo(job.runId, info) - closeLogWriter(job.runId) - } - protected def processJobEndEvent(job: ActiveJob, reason: String) { - var info = "JOB_ID=" + job.runId + " STATUS=CANCELLED REASON=" + reason - jobLogInfo(job.runId, info) + protected def processJobEndEvent(job: ActiveJob, reason: JobResult) { + var info = "JOB_ID=" + job.runId + reason match { + case JobSucceeded => info += " STATUS=SUCCESS" + case JobFailed(exception) => + info += " STATUS=FAILED REASON=" + exception.getMessage.split("\\s+").foreach(info += _ + "_") + case _ => + } + jobLogInfo(job.runId, info.substring(0, info.length - 1).toUpperCase) closeLogWriter(job.runId) } - + protected def recordJobProperties(jobID: Int, properties: Properties) { if(properties != null) { val annotation = properties.getProperty("spark.job.annotation", "") diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index 9265261dc1..bac984b5c9 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -3,52 +3,49 @@ package spark.scheduler import java.util.Properties import spark.scheduler.cluster.TaskInfo import spark.util.Distribution -import spark.{Utils, Logging, SparkContext, TaskEndReason} +import spark.{Logging, SparkContext, TaskEndReason, Utils} import spark.executor.TaskMetrics - sealed trait SparkListenerEvents case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents -case class SparkListenerTaskEnd(event: CompletionEvent) extends SparkListenerEvents +case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, + taskMetrics: TaskMetrics) extends SparkListenerEvents case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) extends SparkListenerEvents - -case class SparkListenerJobSuccess(job: ActiveJob) extends SparkListenerEvents - -case class SparkListenerJobFailed(job: ActiveJob, failedStage: Stage) extends SparkListenerEvents -case class SparkListenerJobCancelled(job: ActiveJob, reason: String) extends SparkListenerEvents +case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) + extends SparkListenerEvents trait SparkListener { /** - * called when a stage is completed, with information on the completed stage + * Called when a stage is completed, with information on the completed stage */ def onStageCompleted(stageCompleted: StageCompleted) { } /** - * called when a stage is submitted + * Called when a stage is submitted */ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } /** - * called when a task ends + * Called when a task ends */ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } /** - * called when a job starts + * Called when a job starts */ def onJobStart(jobStart: SparkListenerJobStart) { } /** - * called when a job ends + * Called when a job ends */ - def onJobEnd(jobEnd: SparkListenerEvents) { } + def onJobEnd(jobEnd: SparkListenerJobEnd) { } } diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala index a654bf3ffd..4000c4d520 100644 --- 
a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -87,7 +87,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers var onStageCompletedCount = 0 var onStageSubmittedCount = 0 override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 - override def onJobEnd(jobEnd: SparkListenerEvents) = onJobEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 -- cgit v1.2.3 From aa7aa587beff22e2db50d2afadd95097856a299a Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 17:48:41 +0800 Subject: some format modification --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 5 ++--- core/src/main/scala/spark/scheduler/JobLogger.scala | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 4336f2f36d..e412baa803 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -324,8 +324,7 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, - JobFailed(error)))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } return true } @@ -671,9 +670,9 @@ class DAGScheduler( val job = resultStageToJob(resultStage) val error = new SparkException("Job failed: " + reason) job.listener.jobFailed(error) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) activeJobs -= job resultStageToJob -= resultStage - sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index 760a0252b7..178bfaba3d 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -70,7 +70,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - // create a log file for one job, the file name is the jobID + // Create a log file for one job, the file name is the jobID protected def createLogWriter(jobID: Int) { try{ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) @@ -80,7 +80,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - // close log file for one job, and clean the stage relationship in stageIDToJobID + // Close log file, and clean the stage relationship in stageIDToJobID protected def closeLogWriter(jobID: Int) = jobIDToPrintWriter.get(jobID).foreach { fileWriter => fileWriter.close() @@ -91,7 +91,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { jobIDToStages -= jobID } - // write log information to log file, withTime parameter controls whether to recored + // Write log information to log file, withTime parameter controls 
whether to recored // time stamp for the information protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { var writeInfo = info @@ -145,7 +145,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - // generate indents and convert to String + // Generate indents and convert to String protected def indentString(indent: Int) = { val sb = new StringBuilder() for (i <- 1 to indent) { @@ -190,7 +190,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false) } - // record task metrics into job log files + // Record task metrics into job log files protected def recordTaskMetrics(stageID: Int, status: String, taskInfo: TaskInfo, taskMetrics: TaskMetrics) { val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + -- cgit v1.2.3 From 4b9862ac9cf2d00c5245e9a8b0fcb05b82030c98 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 17:55:32 +0800 Subject: small format modification --- core/src/main/scala/spark/Utils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1cfaee79b1..96d86647f8 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -522,8 +522,8 @@ private object Utils extends Logging { execute(command, new File(".")) } - class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. -- cgit v1.2.3 From 2fc794a6c7f1b86e5c0103a9c82af2be7fafb347 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 18:21:35 +0800 Subject: small modify in DAGScheduler --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index e412baa803..f7d60be5db 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -289,7 +289,6 @@ class DAGScheduler( val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -299,6 +298,7 @@ class DAGScheduler( // Compute very short actions like first() or take() with no parent stages locally. 
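The scheduler and listener commits above all follow one pattern: events are case classes, the listener trait supplies no-op default callbacks, and the scheduler notifies everyone with a plain foreach. A stripped-down sketch of that pattern, with simplified names rather than Spark's real API:

    object ListenerSketch {
      sealed trait JobResult
      case object JobSucceeded extends JobResult
      case class JobFailed(exception: Exception) extends JobResult

      sealed trait Event
      case class JobStart(jobId: Int) extends Event
      case class JobEnd(jobId: Int, result: JobResult) extends Event

      // No-op defaults let implementers override only the callbacks they care about.
      trait Listener {
        def onJobStart(e: JobStart): Unit = {}
        def onJobEnd(e: JobEnd): Unit = {}
      }

      class Scheduler {
        private var listeners = List.empty[Listener]
        def addListener(l: Listener): Unit = listeners ::= l

        // The scheduler just fans events out to every registered listener.
        def post(event: Event): Unit = event match {
          case s: JobStart => listeners.foreach(_.onJobStart(s))
          case e: JobEnd   => listeners.foreach(_.onJobEnd(e))
        }
      }

      def main(args: Array[String]): Unit = {
        val sched = new Scheduler
        sched.addListener(new Listener {
          override def onJobEnd(e: JobEnd): Unit = println(s"job ${e.jobId} ended: ${e.result}")
        })
        sched.post(JobStart(0))
        sched.post(JobEnd(0, JobSucceeded))
      }
    }
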
runLocally(job) } else { + sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) idToActiveJob(runId) = job activeJobs += job resultStageToJob(finalStage) = job -- cgit v1.2.3 From 40afe0d2a5562738ef2ff37ed1d448ae2d0cc927 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Sun, 10 Mar 2013 13:54:46 -0700 Subject: Add Python timing instrumentation --- core/src/main/scala/spark/api/python/PythonRDD.scala | 12 ++++++++++++ python/pyspark/serializers.py | 4 ++++ python/pyspark/worker.py | 16 +++++++++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 807119ca8c..e9978d713f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -47,6 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest]( currentEnvVars.put(variable, value) } + val startTime = System.currentTimeMillis val proc = pb.start() val env = SparkEnv.get @@ -108,6 +109,17 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj + case -3 => + // Timing data from child + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + read case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 115cf28cc2..5a95144983 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -46,6 +46,10 @@ def read_long(stream): return struct.unpack("!q", length)[0] +def write_long(value, stream): + stream.write(struct.pack("!q", value)) + + def read_int(stream): length = stream.read(4) if length == "": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 812e7a9da5..4c33ae49dc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,6 +1,8 @@ """ Worker that receives input from Piped RDD. """ +import time +preboot_time = time.time() import os import sys import traceback @@ -12,7 +14,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file # Redirect stdout to stderr so that users must return values from functions. 
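On the Scala side of this commit, PythonRDD treats a positive int on the stream as the length of a payload and negative ints as control records: -3 introduces three longs of timing data and -1 ends the data section. Below is a small round-trip sketch of that framing convention; only the -3 and -1 codes come from the patch, everything else is simplified.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object FramingSketch {
      def writeFrame(out: DataOutputStream, payload: Array[Byte]): Unit = {
        out.writeInt(payload.length)          // positive int: length-prefixed payload
        out.write(payload)
      }

      def writeTiming(out: DataOutputStream, boot: Long, init: Long, finish: Long): Unit = {
        out.writeInt(-3)                      // control code for a timing record
        out.writeLong(boot); out.writeLong(init); out.writeLong(finish)
      }

      def readAll(in: DataInputStream): Unit = {
        var done = false
        while (!done) {
          in.readInt() match {
            case length if length > 0 =>
              val obj = new Array[Byte](length)
              in.readFully(obj)
              println("payload: " + new String(obj, "UTF-8"))
            case -3 =>
              val boot = in.readLong(); val init = in.readLong(); val finish = in.readLong()
              println(s"times: boot=$boot init=$init finish=$finish")
            case _ =>
              done = true                     // -1 (or anything else) ends the stream here
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val buf = new ByteArrayOutputStream()
        val out = new DataOutputStream(buf)
        writeFrame(out, "hello".getBytes("UTF-8"))
        writeTiming(out, 12L, 3L, 45L)
        out.writeInt(-1)
        out.flush()
        readAll(new DataInputStream(new ByteArrayInputStream(buf.toByteArray)))
      }
    }
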
@@ -24,7 +26,16 @@ def load_obj(): return load_pickle(standard_b64decode(sys.stdin.readline().strip())) +def report_times(preboot, boot, init, finish): + write_int(-3, old_stdout) + write_long(1000 * preboot, old_stdout) + write_long(1000 * boot, old_stdout) + write_long(1000 * init, old_stdout) + write_long(1000 * finish, old_stdout) + + def main(): + boot_time = time.time() split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir @@ -41,6 +52,7 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle + init_time = time.time() iterator = read_from_pickle_file(sys.stdin) try: for obj in func(split_index, iterator): @@ -49,6 +61,8 @@ def main(): write_int(-2, old_stdout) write_with_length(traceback.format_exc(), old_stdout) sys.exit(-1) + finish_time = time.time() + report_times(preboot_time, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output write_int(-1, old_stdout) for aid, accum in _accumulatorRegistry.items(): -- cgit v1.2.3 From c79a6078c34c207ad9f9910252f5849424828bf1 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Mon, 6 May 2013 16:34:30 -0700 Subject: Prefork Python worker processes --- core/src/main/scala/spark/SparkEnv.scala | 11 +++ .../main/scala/spark/api/python/PythonRDD.scala | 66 +++++-------- .../main/scala/spark/api/python/PythonWorker.scala | 89 +++++++++++++++++ python/pyspark/daemon.py | 109 +++++++++++++++++++++ python/pyspark/worker.py | 61 ++++++------ 5 files changed, 263 insertions(+), 73 deletions(-) create mode 100644 core/src/main/scala/spark/api/python/PythonWorker.scala create mode 100644 python/pyspark/daemon.py diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index be1a04d619..5691e24c32 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,5 +1,8 @@ package spark +import collection.mutable +import serializer.Serializer + import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} import akka.remote.RemoteActorRefProvider @@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster import spark.network.ConnectionManager import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils +import spark.api.python.PythonWorker /** @@ -37,6 +41,8 @@ class SparkEnv ( // If executorId is NOT found, return defaultHostPort var executorIdToHostPort: Option[(String, String) => String]) { + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]() + def stop() { httpFileServer.stop() mapOutputTracker.stop() @@ -50,6 +56,11 @@ class SparkEnv ( actorSystem.awaitTermination() } + def getPythonWorker(pythonExec: String, envVars: Map[String, String]): PythonWorker = { + synchronized { + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorker(pythonExec, envVars)) + } + } def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { val env = SparkEnv.get diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index e9978d713f..e5acc54c01 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -2,10 +2,9 @@ package spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, 
Collections} import scala.collection.JavaConversions._ -import scala.io.Source import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast @@ -16,7 +15,7 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], - envVars: java.util.Map[String, String], + envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], @@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest]( // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + def this(parent: RDD[T], command: String, envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]]) = @@ -36,36 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) - // Add the environmental variables to the process. - val currentEnvVars = pb.environment() - - for ((variable, value) <- envVars) { - currentEnvVars.put(variable, value) - } + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis - val proc = pb.start() + val worker = SparkEnv.get.getPythonWorker(pythonExec, envVars.toMap).create val env = SparkEnv.get - // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val out = new PrintWriter(proc.getOutputStream) - val dOut = new DataOutputStream(proc.getOutputStream) + val out = new PrintWriter(worker.getOutputStream) + val dOut = new DataOutputStream(worker.getOutputStream) // Partition index dOut.writeInt(split.index) // sparkFilesDir @@ -89,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest]( } dOut.flush() out.flush() - proc.getOutputStream.close() + worker.shutdownOutput() } }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(proc.getInputStream) + val stream = new DataInputStream(worker.getInputStream) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj - _nextObj = read() + if (hasNext) { + // FIXME: can deadlock if worker is waiting for us to + // respond to current message (currently irrelevant because + // output is shutdown before we read any input) + _nextObj = read() + } obj } @@ -110,7 +96,7 @@ private[spark] class PythonRDD[T: ClassManifest]( stream.readFully(obj) obj case -3 => - // Timing data from child + // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() val finishTime = stream.readLong() @@ -127,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest]( stream.readFully(obj) throw new 
PythonException(new String(obj)) case -1 => - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() + // We've finished the data section of the output, but we can still + // read some accumulator updates; let's do that, breaking when we + // get a negative length record. + var len2 = stream.readInt + while (len2 >= 0) { val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + len2 = stream.readInt } new Array[Byte](0) } } catch { case eof: EOFException => { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - new Array[Byte](0) + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } case e => throw e } @@ -171,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PairwiseRDD: unexpected value: " + x) + case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -227,7 +211,7 @@ private[spark] object PythonRDD { dOut.write(s) dOut.writeByte(Pickle.STOP) } else { - throw new Exception("Unexpected RDD type") + throw new SparkException("Unexpected RDD type") } } diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala new file mode 100644 index 0000000000..8ee3c6884f --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonWorker.scala @@ -0,0 +1,89 @@ +package spark.api.python + +import java.io.DataInputStream +import java.net.{Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import spark._ + +private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String]) + extends Logging { + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon + + // Attempt to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon + startDaemon + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + daemon = pb.start() + daemonPort = new DataInputStream(daemon.getInputStream).readInt + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + override def run() { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. 
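The comment above refers to copying the daemon's stderr byte for byte instead of decoding it line by line. A generic sketch of such a redirect thread follows; the child process in main is only an illustration and assumes a POSIX shell is available.

    import java.io.{InputStream, OutputStream}

    object RedirectSketch {
      // Copy raw bytes from `in` to `out` on a background thread; no charset
      // decoding is involved, so arbitrary child output passes through intact.
      def redirect(in: InputStream, out: OutputStream, name: String): Thread = {
        val t = new Thread(name) {
          override def run(): Unit = {
            val buf = new Array[Byte](1024)
            var len = in.read(buf)
            while (len != -1) {
              out.write(buf, 0, len)
              out.flush()
              len = in.read(buf)
            }
          }
        }
        t.setDaemon(true)   // don't keep the JVM alive just for the copier
        t.start()
        t
      }

      def main(args: Array[String]): Unit = {
        val proc = new ProcessBuilder("sh", "-c", "echo hello from child >&2").start()
        redirect(proc.getErrorStream, System.err, "stderr copier for child")
        proc.waitFor()
      }
    }
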
+ val in = daemon.getErrorStream + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + }.start() + } catch { + case e => { + stopDaemon + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. + } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy + } + + daemon = null + daemonPort = 0 + } + } +} diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py new file mode 100644 index 0000000000..642f30b2b9 --- /dev/null +++ b/python/pyspark/daemon.py @@ -0,0 +1,109 @@ +import os +import sys +import multiprocessing +from errno import EINTR, ECHILD +from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN +from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN +from pyspark.worker import main as worker_main +from pyspark.serializers import write_int + +try: + POOLSIZE = multiprocessing.cpu_count() +except NotImplementedError: + POOLSIZE = 4 + +should_exit = False + + +def worker(listen_sock): + # Redirect stdout to stderr + os.dup2(2, 1) + + # Manager sends SIGHUP to request termination of workers in the pool + def handle_sighup(signum, frame): + global should_exit + should_exit = True + signal(SIGHUP, handle_sighup) + + while not should_exit: + # Wait until a client arrives or we have to exit + sock = None + while not should_exit and sock is None: + try: + sock, addr = listen_sock.accept() + except EnvironmentError as err: + if err.errno != EINTR: + raise + + if sock is not None: + # Fork a child to handle the client + if os.fork() == 0: + # Leave the worker pool + signal(SIGHUP, SIG_DFL) + listen_sock.close() + # Handle the client then exit + sockfile = sock.makefile() + worker_main(sockfile, sockfile) + sockfile.close() + sock.close() + os._exit(0) + else: + sock.close() + + assert should_exit + os._exit(0) + + +def manager(): + # Create a new process group to corral our children + os.setpgid(0, 0) + + # Create a listening socket on the AF_INET loopback interface + listen_sock = socket(AF_INET, SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) + listen_host, listen_port = listen_sock.getsockname() + write_int(listen_port, sys.stdout) + + # Launch initial worker pool + for idx in range(POOLSIZE): + if os.fork() == 0: + worker(listen_sock) + raise RuntimeError("worker() unexpectedly returned") + listen_sock.close() + + def shutdown(): + global should_exit + os.kill(0, SIGHUP) + should_exit = True + + # Gracefully exit on SIGTERM, don't die on SIGHUP + signal(SIGTERM, lambda signum, frame: shutdown()) + signal(SIGHUP, SIG_IGN) + + # Cleanup zombie children + def handle_sigchld(signum, frame): + try: + pid, status = os.waitpid(0, os.WNOHANG) + if (pid, status) != (0, 0) and not should_exit: + raise RuntimeError("pool member crashed: %s, %s" % (pid, status)) + except EnvironmentError as err: + if err.errno not in (ECHILD, EINTR): + raise + signal(SIGCHLD, handle_sigchld) + + # Initialization complete + sys.stdout.close() + while not should_exit: + try: + # Spark tells us to exit by closing stdin + if sys.stdin.read() == '': + shutdown() + except EnvironmentError as err: + if err.errno != EINTR: + shutdown() + raise + + +if __name__ == '__main__': + manager() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 
4c33ae49dc..94d612ea6e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,10 +1,9 @@ """ Worker that receives input from Piped RDD. """ -import time -preboot_time = time.time() import os import sys +import time import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the @@ -17,57 +16,55 @@ from pyspark.serializers import write_with_length, read_with_length, write_int, read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file -# Redirect stdout to stderr so that users must return values from functions. -old_stdout = os.fdopen(os.dup(1), 'w') -os.dup2(2, 1) - - -def load_obj(): - return load_pickle(standard_b64decode(sys.stdin.readline().strip())) +def load_obj(infile): + return load_pickle(standard_b64decode(infile.readline().strip())) -def report_times(preboot, boot, init, finish): - write_int(-3, old_stdout) - write_long(1000 * preboot, old_stdout) - write_long(1000 * boot, old_stdout) - write_long(1000 * init, old_stdout) - write_long(1000 * finish, old_stdout) +def report_times(outfile, boot, init, finish): + write_int(-3, outfile) + write_long(1000 * boot, outfile) + write_long(1000 * init, outfile) + write_long(1000 * finish, outfile) -def main(): +def main(infile, outfile): boot_time = time.time() - split_index = read_int(sys.stdin) - spark_files_dir = load_pickle(read_with_length(sys.stdin)) + split_index = read_int(infile) + spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) - num_broadcast_variables = read_int(sys.stdin) + num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): - bid = read_long(sys.stdin) - value = read_with_length(sys.stdin) + bid = read_long(infile) + value = read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) - func = load_obj() - bypassSerializer = load_obj() + func = load_obj(infile) + bypassSerializer = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(sys.stdin) + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + write_with_length(dumps(obj), outfile) except Exception as e: - write_int(-2, old_stdout) - write_with_length(traceback.format_exc(), old_stdout) - sys.exit(-1) + write_int(-2, outfile) + write_with_length(traceback.format_exc(), outfile) + raise finish_time = time.time() - report_times(preboot_time, boot_time, init_time, finish_time) + report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, old_stdout) + write_int(-1, outfile) for aid, accum in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), old_stdout) + write_with_length(dump_pickle((aid, accum._value)), outfile) + write_int(-1, outfile) if __name__ == '__main__': - main() + # Redirect stdout to stderr so that users must return values from functions. 
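Back on the JVM side of this commit, SparkEnv keeps one worker object per (python executable, environment) key. The sketch below shows that caching shape in isolation: getOrElseUpdate under a lock gives lazy creation, reuse on later calls, and one place to stop everything. The names are placeholders, not Spark's.

    import scala.collection.mutable

    // Stand-in for something that launches a daemon and hands out connections.
    class WorkerFactory(exec: String, env: Map[String, String]) {
      def create(): String = s"worker connection for $exec"
      def stop(): Unit = println(s"stopping daemon for $exec")
    }

    class WorkerRegistry {
      private val factories =
        mutable.HashMap[(String, Map[String, String]), WorkerFactory]()

      // Lazily build one factory per key and reuse it afterwards.
      def createWorker(exec: String, env: Map[String, String]): String = synchronized {
        factories.getOrElseUpdate((exec, env), new WorkerFactory(exec, env)).create()
      }

      // Shut every cached factory down exactly once.
      def stop(): Unit = synchronized {
        factories.values.foreach(_.stop())
        factories.clear()
      }
    }

    object WorkerRegistryDemo {
      def main(args: Array[String]): Unit = {
        val registry = new WorkerRegistry
        println(registry.createWorker("python", Map("PYTHONPATH" -> "/tmp")))
        println(registry.createWorker("python", Map("PYTHONPATH" -> "/tmp")))  // reuses the factory
        registry.stop()
      }
    }
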
+ old_stdout = os.fdopen(os.dup(1), 'w') + os.dup2(2, 1) + main(sys.stdin, old_stdout) -- cgit v1.2.3 From 62c4781400dd908c2fccdcebf0dc816ff0cb8ed4 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Fri, 10 May 2013 15:48:48 -0700 Subject: Add tests and fixes for Python daemon shutdown --- core/src/main/scala/spark/SparkEnv.scala | 1 + .../main/scala/spark/api/python/PythonWorker.scala | 4 ++ python/pyspark/daemon.py | 46 +++++++++++----------- python/pyspark/tests.py | 43 ++++++++++++++++++++ python/pyspark/worker.py | 2 + 5 files changed, 74 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 5691e24c32..5b55d45212 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -44,6 +44,7 @@ class SparkEnv ( private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]() def stop() { + pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() shuffleFetcher.stop() diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala index 8ee3c6884f..74c8c6d37a 100644 --- a/core/src/main/scala/spark/api/python/PythonWorker.scala +++ b/core/src/main/scala/spark/api/python/PythonWorker.scala @@ -33,6 +33,10 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, Strin } } + def stop() { + stopDaemon + } + private def startDaemon() { synchronized { // Is it already running? diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 642f30b2b9..ab9c19df57 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -12,7 +12,7 @@ try: except NotImplementedError: POOLSIZE = 4 -should_exit = False +should_exit = multiprocessing.Event() def worker(listen_sock): @@ -21,14 +21,13 @@ def worker(listen_sock): # Manager sends SIGHUP to request termination of workers in the pool def handle_sighup(signum, frame): - global should_exit - should_exit = True + assert should_exit.is_set() signal(SIGHUP, handle_sighup) - while not should_exit: + while not should_exit.is_set(): # Wait until a client arrives or we have to exit sock = None - while not should_exit and sock is None: + while not should_exit.is_set() and sock is None: try: sock, addr = listen_sock.accept() except EnvironmentError as err: @@ -36,8 +35,8 @@ def worker(listen_sock): raise if sock is not None: - # Fork a child to handle the client - if os.fork() == 0: + # Fork to handle the client + if os.fork() != 0: # Leave the worker pool signal(SIGHUP, SIG_DFL) listen_sock.close() @@ -50,7 +49,7 @@ def worker(listen_sock): else: sock.close() - assert should_exit + assert should_exit.is_set() os._exit(0) @@ -73,9 +72,7 @@ def manager(): listen_sock.close() def shutdown(): - global should_exit - os.kill(0, SIGHUP) - should_exit = True + should_exit.set() # Gracefully exit on SIGTERM, don't die on SIGHUP signal(SIGTERM, lambda signum, frame: shutdown()) @@ -85,8 +82,8 @@ def manager(): def handle_sigchld(signum, frame): try: pid, status = os.waitpid(0, os.WNOHANG) - if (pid, status) != (0, 0) and not should_exit: - raise RuntimeError("pool member crashed: %s, %s" % (pid, status)) + if status != 0 and not should_exit.is_set(): + raise RuntimeError("worker crashed: %s, %s" % (pid, status)) except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): raise @@ -94,15 +91,20 @@ def manager(): # Initialization complete sys.stdout.close() - while not 
should_exit: - try: - # Spark tells us to exit by closing stdin - if sys.stdin.read() == '': - shutdown() - except EnvironmentError as err: - if err.errno != EINTR: - shutdown() - raise + try: + while not should_exit.is_set(): + try: + # Spark tells us to exit by closing stdin + if os.read(0, 512) == '': + shutdown() + except EnvironmentError as err: + if err.errno != EINTR: + shutdown() + raise + finally: + should_exit.set() + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) if __name__ == '__main__': diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6a1962d267..1e34d47365 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -12,6 +12,7 @@ import unittest from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME +from pyspark.serializers import read_int class PySparkTestCase(unittest.TestCase): @@ -117,5 +118,47 @@ class TestIO(PySparkTestCase): self.sc.parallelize([1]).foreach(func) +class TestDaemon(unittest.TestCase): + def connect(self, port): + from socket import socket, AF_INET, SOCK_STREAM + sock = socket(AF_INET, SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + # send a split index of -1 to shutdown the worker + sock.send("\xFF\xFF\xFF\xFF") + sock.close() + return True + + def do_termination_test(self, terminator): + from subprocess import Popen, PIPE + from errno import ECONNREFUSED + + # start daemon + daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") + daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + + # read the port number + port = read_int(daemon.stdout) + + # daemon should accept connections + self.assertTrue(self.connect(port)) + + # request shutdown + terminator(daemon) + time.sleep(1) + + # daemon should no longer accept connections + with self.assertRaises(EnvironmentError) as trap: + self.connect(port) + self.assertEqual(trap.exception.errno, ECONNREFUSED) + + def test_termination_stdin(self): + """Ensure that daemon and workers terminate when stdin is closed.""" + self.do_termination_test(lambda daemon: daemon.stdin.close()) + + def test_termination_sigterm(self): + """Ensure that daemon and workers terminate on SIGTERM.""" + from signal import SIGTERM + self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 94d612ea6e..f76ee3c236 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,6 +30,8 @@ def report_times(outfile, boot, init, finish): def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) + if split_index == -1: # for unit tests + return spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True -- cgit v1.2.3 From edb18ca928c988a713b9228bb74af1737f2b614b Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Mon, 13 May 2013 08:53:47 -0700 Subject: Rename PythonWorker to PythonWorkerFactory --- core/src/main/scala/spark/SparkEnv.scala | 8 +- .../main/scala/spark/api/python/PythonRDD.scala | 2 +- .../main/scala/spark/api/python/PythonWorker.scala | 93 --------------------- .../spark/api/python/PythonWorkerFactory.scala | 95 ++++++++++++++++++++++ 4 files changed, 100 insertions(+), 98 deletions(-) delete mode 100644 core/src/main/scala/spark/api/python/PythonWorker.scala create mode 100644 core/src/main/scala/spark/api/python/PythonWorkerFactory.scala 
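The renamed class keeps the same behaviour: create() connects to a long-lived daemon, and if the connection fails it assumes the daemon died, restarts it, and retries once. Here is a runnable sketch of that connect-with-one-retry shape, using a local ServerSocket as a stand-in for the forked daemon; it is an analogy, not the factory's actual code.

    import java.net.{InetAddress, ServerSocket, Socket, SocketException}

    object ConnectRetrySketch {
      private val host = InetAddress.getByName("127.0.0.1")
      private var daemon: ServerSocket = null        // stand-in for the daemon process

      private def startDaemon(): Unit = synchronized {
        if (daemon == null) daemon = new ServerSocket(0, 50, host)   // pick a free port
      }

      private def stopDaemon(): Unit = synchronized {
        if (daemon != null) { daemon.close(); daemon = null }
      }

      // Connect to the daemon; on a socket error, restart it and retry exactly once.
      def connect(): Socket = synchronized {
        startDaemon()
        try new Socket(host, daemon.getLocalPort)
        catch {
          case _: SocketException =>
            stopDaemon()
            startDaemon()
            new Socket(host, daemon.getLocalPort)
        }
      }

      def main(args: Array[String]): Unit = {
        val sock = connect()
        println("connected to " + sock.getRemoteSocketAddress)
        sock.close()
        stopDaemon()
      }
    }
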
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 5b55d45212..0a23c45658 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -12,7 +12,7 @@ import spark.storage.BlockManagerMaster import spark.network.ConnectionManager import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils -import spark.api.python.PythonWorker +import spark.api.python.PythonWorkerFactory /** @@ -41,7 +41,7 @@ class SparkEnv ( // If executorId is NOT found, return defaultHostPort var executorIdToHostPort: Option[(String, String) => String]) { - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]() + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() def stop() { pythonWorkers.foreach { case(key, worker) => worker.stop() } @@ -57,9 +57,9 @@ class SparkEnv ( actorSystem.awaitTermination() } - def getPythonWorker(pythonExec: String, envVars: Map[String, String]): PythonWorker = { + def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { synchronized { - pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorker(pythonExec, envVars)) + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create } } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index e5acc54c01..3c48071b3f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -38,8 +38,8 @@ private[spark] class PythonRDD[T: ClassManifest]( override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis - val worker = SparkEnv.get.getPythonWorker(pythonExec, envVars.toMap).create val env = SparkEnv.get + val worker = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + pythonExec) { diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala deleted file mode 100644 index 74c8c6d37a..0000000000 --- a/core/src/main/scala/spark/api/python/PythonWorker.scala +++ /dev/null @@ -1,93 +0,0 @@ -package spark.api.python - -import java.io.DataInputStream -import java.net.{Socket, SocketException, InetAddress} - -import scala.collection.JavaConversions._ - -import spark._ - -private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String]) - extends Logging { - var daemon: Process = null - val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) - var daemonPort: Int = 0 - - def create(): Socket = { - synchronized { - // Start the daemon if it hasn't been started - startDaemon - - // Attempt to connect, restart and retry once if it fails - try { - new Socket(daemonHost, daemonPort) - } catch { - case exc: SocketException => { - logWarning("Python daemon unexpectedly quit, attempting to restart") - stopDaemon - startDaemon - new Socket(daemonHost, daemonPort) - } - case e => throw e - } - } - } - - def stop() { - stopDaemon - } - - private def startDaemon() { - synchronized { - // Is it already running? 
- if (daemon != null) { - return - } - - try { - // Create and start the daemon - val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) - val workerEnv = pb.environment() - workerEnv.putAll(envVars) - daemon = pb.start() - daemonPort = new DataInputStream(daemon.getInputStream).readInt - - // Redirect the stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - // FIXME HACK: We copy the stream on the level of bytes to - // attempt to dodge encoding problems. - val in = daemon.getErrorStream - var buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - }.start() - } catch { - case e => { - stopDaemon - throw e - } - } - - // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly - // detect our disappearance. - } - } - - private def stopDaemon() { - synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy - } - - daemon = null - daemonPort = 0 - } - } -} diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala new file mode 100644 index 0000000000..ebbd226b3e --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -0,0 +1,95 @@ +package spark.api.python + +import java.io.{DataInputStream, IOException} +import java.net.{Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import spark._ + +private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends Logging { + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon + + // Attempt to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon + startDaemon + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + def stop() { + stopDaemon + } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + daemon = pb.start() + daemonPort = new DataInputStream(daemon.getInputStream).readInt + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + val in = daemon.getErrorStream + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + } catch { + case e => { + stopDaemon + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. 
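The comment above describes a lifetime convention: the daemon watches its own stdin and treats EOF as "the parent is gone, shut down", so the parent can stop it either by closing that stream or simply by exiting. A small demonstration of the idea with a stand-in child process; it assumes a POSIX shell and is not the real daemon.

    object StdinLifetimeSketch {
      def main(args: Array[String]): Unit = {
        // The child loops reading stdin and exits as soon as it sees EOF.
        val pb = new ProcessBuilder("sh", "-c",
          "while read line; do :; done; echo 'stdin closed, exiting' >&2")
        pb.redirectError(ProcessBuilder.Redirect.INHERIT)
        val child = pb.start()

        // ... the child would serve requests while we hold its stdin open ...
        Thread.sleep(500)

        // Closing the child's stdin is the shutdown signal.
        child.getOutputStream.close()
        println("child exited with status " + child.waitFor())
      }
    }
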
+ } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy + } + + daemon = null + daemonPort = 0 + } + } +} -- cgit v1.2.3 From 7c5ff733ee1d3729b4b26f7c5542ca00c4d64139 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 23 May 2013 11:50:24 -0700 Subject: PySpark daemon: fix deadlock, improve error handling --- python/pyspark/daemon.py | 67 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index ab9c19df57..2b5e9b3581 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -1,6 +1,7 @@ import os import sys import multiprocessing +from ctypes import c_bool from errno import EINTR, ECHILD from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -12,7 +13,12 @@ try: except NotImplementedError: POOLSIZE = 4 -should_exit = multiprocessing.Event() +exit_flag = multiprocessing.Value(c_bool, False) + + +def should_exit(): + global exit_flag + return exit_flag.value def worker(listen_sock): @@ -20,14 +26,29 @@ def worker(listen_sock): os.dup2(2, 1) # Manager sends SIGHUP to request termination of workers in the pool - def handle_sighup(signum, frame): - assert should_exit.is_set() + def handle_sighup(*args): + assert should_exit() signal(SIGHUP, handle_sighup) - while not should_exit.is_set(): + # Cleanup zombie children + def handle_sigchld(*args): + pid = status = None + try: + while (pid, status) != (0, 0): + pid, status = os.waitpid(0, os.WNOHANG) + except EnvironmentError as err: + if err.errno == EINTR: + # retry + handle_sigchld() + elif err.errno != ECHILD: + raise + signal(SIGCHLD, handle_sigchld) + + # Handle clients + while not should_exit(): # Wait until a client arrives or we have to exit sock = None - while not should_exit.is_set() and sock is None: + while not should_exit() and sock is None: try: sock, addr = listen_sock.accept() except EnvironmentError as err: @@ -35,8 +56,10 @@ def worker(listen_sock): raise if sock is not None: - # Fork to handle the client - if os.fork() != 0: + # Fork a child to handle the client. + # The client is handled in the child so that the manager + # never receives SIGCHLD unless a worker crashes. 
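The daemon.py change beginning above replaces a multiprocessing.Event with a boolean multiprocessing.Value that forked workers poll, which is what removes the deadlock. There is no direct JVM equivalent of a fork-shared value, but the closest thread-level idiom, a polled atomic flag for cooperative shutdown, is sketched below purely as an analogy (not what the patch itself does):

    import java.util.concurrent.atomic.AtomicBoolean

    object ShutdownFlag {
      // Thread-level analogue only: the patch shares the flag across forked
      // processes; here worker threads poll a plain atomic boolean instead of
      // blocking on an event object.
      private val exitFlag = new AtomicBoolean(false)

      def shouldExit: Boolean = exitFlag.get()
      def requestShutdown() { exitFlag.set(true) }
    }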
+ if os.fork() == 0: # Leave the worker pool signal(SIGHUP, SIG_DFL) listen_sock.close() @@ -49,8 +72,18 @@ def worker(listen_sock): else: sock.close() - assert should_exit.is_set() - os._exit(0) + +def launch_worker(listen_sock): + if os.fork() == 0: + try: + worker(listen_sock) + except Exception as err: + import traceback + traceback.print_exc() + os._exit(1) + else: + assert should_exit() + os._exit(0) def manager(): @@ -66,23 +99,22 @@ def manager(): # Launch initial worker pool for idx in range(POOLSIZE): - if os.fork() == 0: - worker(listen_sock) - raise RuntimeError("worker() unexpectedly returned") + launch_worker(listen_sock) listen_sock.close() def shutdown(): - should_exit.set() + global exit_flag + exit_flag.value = True # Gracefully exit on SIGTERM, don't die on SIGHUP signal(SIGTERM, lambda signum, frame: shutdown()) signal(SIGHUP, SIG_IGN) # Cleanup zombie children - def handle_sigchld(signum, frame): + def handle_sigchld(*args): try: pid, status = os.waitpid(0, os.WNOHANG) - if status != 0 and not should_exit.is_set(): + if status != 0 and not should_exit(): raise RuntimeError("worker crashed: %s, %s" % (pid, status)) except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): @@ -92,7 +124,7 @@ def manager(): # Initialization complete sys.stdout.close() try: - while not should_exit.is_set(): + while not should_exit(): try: # Spark tells us to exit by closing stdin if os.read(0, 512) == '': @@ -102,7 +134,8 @@ def manager(): shutdown() raise finally: - should_exit.set() + signal(SIGTERM, SIG_DFL) + exit_flag.value = True # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) -- cgit v1.2.3 From 1ba3c173034c37ef99fc312c84943d2ab8885670 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 20 Jun 2013 12:49:10 -0400 Subject: use parens when calling method with side-effects --- core/src/main/scala/spark/SparkEnv.scala | 2 +- core/src/main/scala/spark/api/python/PythonRDD.scala | 4 ++-- .../main/scala/spark/api/python/PythonWorkerFactory.scala | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 0a23c45658..7ccde2e818 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -59,7 +59,7 @@ class SparkEnv ( def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { synchronized { - pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create() } } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 3c48071b3f..63140cf37f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -116,12 +116,12 @@ private[spark] class PythonRDD[T: ClassManifest]( // We've finished the data section of the output, but we can still // read some accumulator updates; let's do that, breaking when we // get a negative length record. 
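The PythonRDD hunk that continues below reads accumulator updates as length-prefixed byte records and stops at a negative length sentinel. Pulled out of Spark, that framing is just the following (object and method names are illustrative):

    import java.io.DataInputStream

    object Framing {
      // Each record is a 4-byte length followed by that many bytes; a negative
      // length marks the end of the record section.
      def readLengthPrefixed(stream: DataInputStream)(handle: Array[Byte] => Unit) {
        var len = stream.readInt()
        while (len >= 0) {
          val record = new Array[Byte](len)
          stream.readFully(record)
          handle(record)
          len = stream.readInt()
        }
      }
    }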
- var len2 = stream.readInt + var len2 = stream.readInt() while (len2 >= 0) { val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt + len2 = stream.readInt() } new Array[Byte](0) } diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala index ebbd226b3e..8844411d73 100644 --- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -16,7 +16,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { synchronized { // Start the daemon if it hasn't been started - startDaemon + startDaemon() // Attempt to connect, restart and retry once if it fails try { @@ -24,8 +24,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } catch { case exc: SocketException => { logWarning("Python daemon unexpectedly quit, attempting to restart") - stopDaemon - startDaemon + stopDaemon() + startDaemon() new Socket(daemonHost, daemonPort) } case e => throw e @@ -34,7 +34,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stop() { - stopDaemon + stopDaemon() } private def startDaemon() { @@ -51,7 +51,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars) daemon = pb.start() - daemonPort = new DataInputStream(daemon.getInputStream).readInt + daemonPort = new DataInputStream(daemon.getInputStream).readInt() // Redirect the stderr to ours new Thread("stderr reader for " + pythonExec) { @@ -71,7 +71,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String }.start() } catch { case e => { - stopDaemon + stopDaemon() throw e } } @@ -85,7 +85,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String synchronized { // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { - daemon.destroy + daemon.destroy() } daemon = null -- cgit v1.2.3 From c75bed0eebb1f937db02eb98deecd380724f747d Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Fri, 21 Jun 2013 12:13:48 -0400 Subject: Fix reporting of PySpark exceptions --- python/pyspark/daemon.py | 22 ++++++++++++++++++---- python/pyspark/worker.py | 2 +- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 2b5e9b3581..78a2da1e18 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -21,6 +21,15 @@ def should_exit(): return exit_flag.value +def compute_real_exit_code(exit_code): + # SystemExit's code can be integer or string, but os._exit only accepts integers + import numbers + if isinstance(exit_code, numbers.Integral): + return exit_code + else: + return 1 + + def worker(listen_sock): # Redirect stdout to stderr os.dup2(2, 1) @@ -65,10 +74,15 @@ def worker(listen_sock): listen_sock.close() # Handle the client then exit sockfile = sock.makefile() - worker_main(sockfile, sockfile) - sockfile.close() - sock.close() - os._exit(0) + exit_code = 0 + try: + worker_main(sockfile, sockfile) + except SystemExit as exc: + exit_code = exc.code + finally: + sockfile.close() + sock.close() + os._exit(compute_real_exit_code(exit_code)) else: sock.close() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f76ee3c236..379bbfd4c2 100644 --- 
a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -55,7 +55,7 @@ def main(infile, outfile): except Exception as e: write_int(-2, outfile) write_with_length(traceback.format_exc(), outfile) - raise + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output -- cgit v1.2.3 From b350f34703d4f29bbd0e603df852f7aae230b2a2 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 07:48:20 -0700 Subject: Increase memory for tests to prevent a crash on JDK 7 --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 24c8b734d0..faf6e2ae8e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -54,7 +54,7 @@ object SparkBuild extends Build { // Fork new JVMs for tests and set Java options for those fork := true, - javaOptions += "-Xmx2g", + javaOptions += "-Xmx2500m", // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), -- cgit v1.2.3 From d92d3f7938dec954ea31de232f50cafd4b644065 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 10:24:19 -0700 Subject: Fix resolution of example code with Maven builds --- run | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/run b/run index c0065c53f1..e656e38ccf 100755 --- a/run +++ b/run @@ -132,10 +132,14 @@ if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" fi CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" +# Add the shaded JAR for Maven builds if [ -e $REPL_BIN_DIR/target ]; then for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do CLASSPATH="$CLASSPATH:$jar" done + # The shaded JAR doesn't contain examples, so include those separately + EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` + CLASSPATH+=":$EXAMPLES_JAR" fi CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do @@ -148,9 +152,9 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; # Use the JAR from the SBT build export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar` fi -if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then +if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then # Use the JAR from the Maven build - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar` + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` fi # Add hadoop conf dir - else FileSystem.*, etc fail ! 
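The SparkBuild change above raises the heap for forked test JVMs. In an sbt 0.12-era Build definition like this project's, those knobs sit together roughly as sketched here (an illustrative project, not the real SparkBuild):

    import sbt._
    import Keys._

    object ExampleBuild extends Build {
      lazy val core = Project("core", file("core"), settings = Defaults.defaultSettings ++ Seq(
        // Run tests in a separate JVM and give it enough headroom for JDK 7.
        fork := true,
        javaOptions += "-Xmx2500m"
      ))
    }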
-- cgit v1.2.3 From b5df1cd668e45fd0cc22c1666136d05548cae3e9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 17:12:39 -0700 Subject: ADD_JARS environment variable for spark-shell --- docs/scala-programming-guide.md | 10 ++++++++-- repl/src/main/scala/spark/repl/SparkILoop.scala | 9 +++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index b0da130fcb..e9cf9ef36f 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -43,12 +43,18 @@ new SparkContext(master, appName, [sparkHome], [jars]) The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later. -In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable. For example, to run on four cores, use +In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `spark-shell` on four cores, use {% highlight bash %} $ MASTER=local[4] ./spark-shell {% endhighlight %} +Or, to also add `code.jar` to its classpath, use: + +{% highlight bash %} +$ MASTER=local[4] ADD_JARS=code.jar ./spark-shell +{% endhighlight %} + ### Master URLs The master URL passed to Spark can be in one of the following formats: @@ -78,7 +84,7 @@ If you want to run your job on a cluster, you will need to specify the two optio * `sparkHome`: The path at which Spark is installed on your worker machines (it should be the same on all of them). * `jars`: A list of JAR files on the local machine containing your job's code and any dependencies, which Spark will deploy to all the worker nodes. You'll need to package your job into a set of JARs using your build system. For example, if you're using SBT, the [sbt-assembly](https://github.com/sbt/sbt-assembly) plugin is a good way to make a single JAR with your code and dependencies. -If you run `spark-shell` on a cluster, any classes you define in the shell will automatically be distributed. +If you run `spark-shell` on a cluster, you can add JARs to it by specifying the `ADD_JARS` environment variable before you launch it. This variable should contain a comma-separated list of JARs. For example, `ADD_JARS=a.jar,b.jar ./spark-shell` will launch a shell with `a.jar` and `b.jar` on its classpath. In addition, any new classes you define in the shell will automatically be distributed. 
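As the guide text above notes, ADD_JARS simply feeds the optional jars argument of the shell's SparkContext. A standalone program passes the same argument directly; a small sketch using the constructor form new SparkContext(master, appName, [sparkHome], [jars]) described above, with placeholder JAR paths:

    import spark.SparkContext

    object ShellStyleApp {
      def main(args: Array[String]) {
        // Placeholder JAR paths; they are shipped to the workers just like ADD_JARS entries.
        val jars = Seq("a.jar", "b.jar")
        val sc = new SparkContext("local[4]", "Shell-style app", System.getenv("SPARK_HOME"), jars)
        println(sc.parallelize(1 to 100).reduce(_ + _))
        sc.stop()
      }
    }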
# Resilient Distributed Datasets (RDDs) diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala index 23556dbc8f..86eed090d0 100644 --- a/repl/src/main/scala/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/spark/repl/SparkILoop.scala @@ -822,7 +822,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: spark.repl.Main.interp.out.println("Spark context available as sc."); spark.repl.Main.interp.out.flush(); """) - command("import spark.SparkContext._"); + command("import spark.SparkContext._") } echo("Type in expressions to have them evaluated.") echo("Type :help for more information.") @@ -838,7 +838,8 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (prop != null) prop else "local" } } - sparkContext = new SparkContext(master, "Spark shell") + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) sparkContext } @@ -850,6 +851,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: printWelcome() echo("Initializing interpreter...") + // Add JARS specified in Spark's ADD_JARS variable to classpath + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + jars.foreach(settings.classpath.append(_)) + this.settings = settings createInterpreter() -- cgit v1.2.3 From 0e0f9d3069039f03bbf5eefe3b0637c89fddf0f1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 17:44:04 -0700 Subject: Fix search path for REPL class loader to really find added JARs --- core/src/main/scala/spark/executor/Executor.scala | 38 +++++++++++++---------- repl/src/main/scala/spark/repl/SparkILoop.scala | 4 ++- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 8bebfafce4..2bf55ea9a9 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -42,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert // Create our ClassLoader and set it on this thread private val urlClassLoader = createClassLoader() - Thread.currentThread.setContextClassLoader(urlClassLoader) + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. 
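The executor change that starts above (and continues in the next hunk) puts a REPL-aware loader in front of the usual URL class loader, looking the implementation up by name so that core never needs a compile-time dependency on the repl module. Reduced to its bare pattern, the reflective lookup is sketched below; the class name comes from the diff, the wrapper object is illustrative, and unlike the patch (which fails fast when the REPL URI is set but the class is missing) this sketch simply falls back to the parent loader:

    object OptionalClassLoader {
      // Instantiate spark.repl.ExecutorClassLoader if it is on the classpath,
      // otherwise keep the parent loader.
      def wrap(classUri: String, parent: ClassLoader): ClassLoader = {
        try {
          val klass = Class.forName("spark.repl.ExecutorClassLoader")
            .asInstanceOf[Class[_ <: ClassLoader]]
          val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
          constructor.newInstance(classUri, parent)
        } catch {
          case _: ClassNotFoundException => parent
        }
      }
    }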
@@ -88,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert override def run() { val startTime = System.currentTimeMillis() SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) @@ -153,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val urls = currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL }.toArray - loader = new URLClassLoader(urls, loader) + new ExecutorURLClassLoader(urls, loader) + } - // If the REPL is in use, add another ClassLoader that will read - // new classes defined by the REPL as the user types code + /** + * If the REPL is in use, add another ClassLoader that will read + * new classes defined by the REPL as the user types code + */ + private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { val classUri = System.getProperty("spark.repl.class.uri") if (classUri != null) { logInfo("Using REPL class URI: " + classUri) - loader = { - try { - val klass = Class.forName("spark.repl.ExecutorClassLoader") - .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - constructor.newInstance(classUri, loader) - } catch { - case _: ClassNotFoundException => loader - } + try { + val klass = Class.forName("spark.repl.ExecutorClassLoader") + .asInstanceOf[Class[_ <: ClassLoader]] + val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) + return constructor.newInstance(classUri, parent) + } catch { + case _: ClassNotFoundException => + logError("Could not find spark.repl.ExecutorClassLoader on classpath!") + System.exit(1) + null } + } else { + return parent } - - return new ExecutorURLClassLoader(Array(), loader) } /** diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala index 86eed090d0..59f9d05683 100644 --- a/repl/src/main/scala/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/spark/repl/SparkILoop.scala @@ -838,7 +838,9 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (prop != null) prop else "local" } } - val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')) + .getOrElse(new Array[String](0)) + .map(new java.io.File(_).getAbsolutePath) sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) sparkContext } -- cgit v1.2.3 From 78ffe164b33c6b11a2e511442605acd2f795a1b5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 23 Jun 2013 10:07:16 -0700 Subject: Clone the zero value for each key in foldByKey The old version reused the object within each task, leading to overwriting of the object when a mutable type is used, which is expected to be common in fold. 
Conflicts: core/src/test/scala/spark/ShuffleSuite.scala --- core/src/main/scala/spark/PairRDDFunctions.scala | 15 ++++++++++++--- core/src/test/scala/spark/ShuffleSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index fa4bbfc76f..7630fe7803 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -1,5 +1,6 @@ package spark +import java.nio.ByteBuffer import java.util.{Date, HashMap => JHashMap} import java.text.SimpleDateFormat @@ -64,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( throw new SparkException("Default partitioner cannot partition array keys.") } } - val aggregator = - new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { @@ -97,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner) + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + // When deserializing, use a lazy val to create just one instance of the serializer per task + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + + combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) } /** diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 1916885a73..0c1ec29f96 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -392,6 +392,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("foldByKey") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } } object ShuffleSuite { -- cgit v1.2.3 From 48c7e373c62b2e8cf48157ceb0d92c38c3a40652 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 24 Jun 2013 23:11:04 -0700 Subject: Minor formatting fixes --- .../src/main/scala/spark/streaming/DStream.scala | 9 +++++-- 
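The foldByKey fix above avoids sharing one mutable zero across keys by serializing the zero once and deserializing a fresh copy whenever a key needs one. The same trick outside Spark, with plain Java serialization standing in for the closure serializer (assumes the value type is serializable; names are illustrative):

    import java.io._

    object ZeroCloner {
      // Capture the zero value as bytes once, then mint an independent copy on demand,
      // so in-place mutation of one key's accumulator can never leak into another's.
      def cloner[V <: Serializable](zero: V): () => V = {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(zero)
        oos.close()
        val bytes = bos.toByteArray
        () => {
          val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
          in.readObject().asInstanceOf[V]
        }
      }
    }

Each call to the returned function yields an independent copy, so folding with in-place mutation, as in the ArrayBuffer test above, cannot bleed state between keys.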
.../scala/spark/streaming/StreamingContext.scala | 29 +++++++++++++--------- .../streaming/api/java/JavaStreamingContext.scala | 15 +++++++---- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e125310861..9be7926a4a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,7 +441,12 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2) + def count(): DStream[Long] = { + this.map(_ => (null, 1L)) + .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) + .reduceByKey(_ + _) + .map(_._2) + } /** * Return a new DStream in which each RDD contains the counts of each distinct value in @@ -457,7 +462,7 @@ abstract class DStream[T: ClassManifest] ( * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: RDD[T] => Unit) { - foreach((r: RDD[T], t: Time) => foreachFunc(r)) + this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) } /** diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 2c6326943d..03d2907323 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -171,10 +171,11 @@ class StreamingContext private ( * should be same. */ def actorStream[T: ClassManifest]( - props: Props, - name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, - supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = { + props: Props, + name: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy + ): DStream[T] = { networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) } @@ -182,9 +183,10 @@ class StreamingContext private ( * Create an input stream that receives messages pushed by a zeromq publisher. * @param publisherUrl Url of remote zeromq publisher * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence - * of byte thus it needs the converter(which might be deserializer of bytes) - * to translate from sequence of sequence of bytes, where sequence refer to a frame + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic + * and each frame has sequence of byte thus it needs the converter + * (which might be deserializer of bytes) to translate from sequence + * of sequence of bytes, where sequence refer to a frame * and sub sequence refer to its payload. * @param storageLevel RDD storage level. Defaults to memory-only. */ @@ -204,7 +206,7 @@ class StreamingContext private ( * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. + * in its own thread. 
* @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ @@ -214,15 +216,17 @@ class StreamingContext private ( topics: Map[String, Int], storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[String] = { - val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); + val kafkaParams = Map[String, String]( + "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000") kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) } /** * Create an input stream that pulls messages from a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. + * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( @@ -395,7 +399,8 @@ class StreamingContext private ( * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. + * Set as null if no RDD should be returned when empty * @tparam T Type of objects in the RDD */ def queueStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index b35d9032f1..fd5e06b50f 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -75,7 +75,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { : JavaDStream[String] = { implicit val cmt: ClassManifest[String] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2) + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + StorageLevel.MEMORY_ONLY_SER_2) } /** @@ -83,8 +84,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only - * in its own thread. 
+ * */ def kafkaStream( zkQuorum: String, @@ -94,14 +96,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { : JavaDStream[String] = { implicit val cmt: ClassManifest[String] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** * Create an input stream that pulls messages form a Kafka Broker. * @param typeClass Type of RDD * @param decoderClass Type of kafka decoder - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only @@ -113,7 +117,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] ssc.kafkaStream[T, D]( kafkaParams.toMap, -- cgit v1.2.3
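For a usage-level view of the streaming API these last hunks touch, the sketch below creates a Kafka stream both ways: the zkQuorum/groupId convenience form and the explicit kafkaParams form it now delegates to. Host names, topic, and batch interval are placeholders, and the constructor and start calls assume the 0.7-era spark.streaming API reflected in these diffs:

    import spark.storage.StorageLevel
    import spark.streaming.{Seconds, StreamingContext}

    object KafkaStreamSketch {
      def main(args: Array[String]) {
        // Placeholders throughout: master, app name, batch interval, ZooKeeper quorum, topic.
        val ssc = new StreamingContext("local[4]", "KafkaStreamSketch", Seconds(2))

        // Convenience form: ZooKeeper quorum, consumer group, topic -> consumer thread count.
        val lines = ssc.kafkaStream("zk1:2181,zk2:2181", "my-group", Map("events" -> 1))

        // Explicit form via the kafkaParams overload, mirroring what the convenience form builds.
        val kafkaParams = Map(
          "zk.connect" -> "zk1:2181,zk2:2181",
          "groupid" -> "my-group",
          "zk.connectiontimeout.ms" -> "10000")
        val moreLines = ssc.kafkaStream[String, kafka.serializer.StringDecoder](
          kafkaParams, Map("events" -> 1), StorageLevel.MEMORY_ONLY_SER_2)

        lines.union(moreLines).count().print()
        ssc.start()
      }
    }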