From cfa8e769a86664722f47182fa572179e8beadcb7 Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 17:16:15 -0600 Subject: KafkaInputDStream improvements. Allows more Kafka configurability --- .../scala/spark/streaming/StreamingContext.scala | 22 +++++++++- .../streaming/dstream/KafkaInputDStream.scala | 48 ++++++++++++++-------- 2 files changed, 51 insertions(+), 19 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 25c67b279b..4e1732adf5 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -199,7 +199,7 @@ class StreamingContext private ( } /** - * Create an input stream that pulls messages form a Kafka Broker. + * Create an input stream that pulls messages from a Kafka Broker. * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -216,7 +216,25 @@ class StreamingContext private ( initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel) + val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); + kafkaStream[T](kafkaParams, topics, initialOffsets, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * @param storageLevel Storage level to use for storing the received objects + */ + def kafkaStream[T: ClassManifest]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ): DStream[T] = { + val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, initialOffsets, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index dc7139cc27..f769fc1cc3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -12,6 +12,8 @@ import kafka.message.{Message, MessageSet, MessageAndMetadata} import kafka.serializer.StringDecoder import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ +import kafka.utils.ZKStringSerializer +import org.I0Itec.zkclient._ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ @@ -23,8 +25,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part /** * Input stream that pulls messages from a Kafka Broker. * - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. + * @param kafkaParams Map of kafka configuration paramaters. 
See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param initialOffsets Optional initial offsets for each of the partitions to consume. @@ -34,8 +35,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part private[streaming] class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, - zkQuorum: String, - groupId: String, + kafkaParams: Map[String, String], topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel @@ -43,19 +43,16 @@ class KafkaInputDStream[T: ClassManifest]( def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel) + new KafkaReceiver(kafkaParams, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(zkQuorum: String, groupId: String, +class KafkaReceiver(kafkaParams: Map[String, String], topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel) extends NetworkReceiver[Any] { - // Timeout for establishing a connection to Zookeper in ms. - val ZK_TIMEOUT = 10000 - // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka @@ -72,20 +69,24 @@ class KafkaReceiver(zkQuorum: String, groupId: String, // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) logInfo("Initial offsets: " + initialOffsets.toString) - // Zookeper connection properties + // Kafka connection properties val props = new Properties() - props.put("zk.connect", zkQuorum) - props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) - props.put("groupid", groupId) + kafkaParams.foreach(param => props.put(param._1, param._2)) // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zkQuorum) + logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zkQuorum) + logInfo("Connected to " + kafkaParams("zk.connect")) + + // When autooffset.reset is 'smallest', it is our responsibility to try and whack the + // consumer group zk node. + if (kafkaParams.get("autooffset.reset").exists(_ == "smallest")) { + tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) + } // If specified, set the topic offset setOffsets(initialOffsets) @@ -97,7 +98,6 @@ class KafkaReceiver(zkQuorum: String, groupId: String, topicMessageStreams.values.foreach { streams => streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } } - } // Overwrites the offets in Zookeper. @@ -122,4 +122,18 @@ class KafkaReceiver(zkQuorum: String, groupId: String, } } } + + // Handles cleanup of consumer group znode. 
Lifted with love from Kafka's + // ConsumerConsole.scala tryCleanupZookeeper() + private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { + try { + val dir = "/consumers/" + groupId + logInfo("Cleaning up temporary zookeeper data under " + dir + ".") + val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) + zk.deleteRecursive(dir) + zk.close() + } catch { + case _ => // swallow + } + } } -- cgit v1.2.3 From d06928321194b11e082986cd2bb2737d9bc3b698 Mon Sep 17 00:00:00 2001 From: seanm Date: Thu, 14 Mar 2013 23:25:35 -0600 Subject: fixing memory leak in kafka MessageHandler --- .../src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index f769fc1cc3..d674b6ee87 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -114,11 +114,8 @@ class KafkaReceiver(kafkaParams: Map[String, String], private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => + for (msgAndMetadata <- stream) { blockGenerator += msgAndMetadata.message - // Keep on handling messages - - true } } } -- cgit v1.2.3 From 33fa1e7e4aca4d9e0edf65d2b768b569305fd044 Mon Sep 17 00:00:00 2001 From: seanm Date: Thu, 14 Mar 2013 23:32:52 -0600 Subject: removing dependency on ZookeeperConsumerConnector + purging last relic of kafka reliability that never solidified (ie- setOffsets) --- .../scala/spark/streaming/StreamingContext.scala | 9 ++----- .../streaming/api/java/JavaStreamingContext.scala | 28 ---------------------- .../streaming/dstream/KafkaInputDStream.scala | 28 ++++------------------ 3 files changed, 6 insertions(+), 59 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4e1732adf5..bb7f216ca7 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -204,8 +204,6 @@ class StreamingContext private ( * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ @@ -213,11 +211,10 @@ class StreamingContext private ( zkQuorum: String, groupId: String, topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); - kafkaStream[T](kafkaParams, topics, initialOffsets, storageLevel) + kafkaStream[T](kafkaParams, topics, storageLevel) } /** @@ -225,16 +222,14 @@ class StreamingContext private ( * @param kafkaParams Map of kafka configuration paramaters. 
See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. * @param storageLevel Storage level to use for storing the received objects */ def kafkaStream[T: ClassManifest]( kafkaParams: Map[String, String], topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, initialOffsets, storageLevel) + val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index f3b40b5b88..2373f4824a 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -84,39 +84,12 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. - */ - def kafkaStream[T]( - zkQuorum: String, - groupId: String, - topics: JMap[String, JInt], - initialOffsets: JMap[KafkaPartitionKey, JLong]) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T]( - zkQuorum, - groupId, - Map(topics.mapValues(_.intValue()).toSeq: _*), - Map(initialOffsets.mapValues(_.longValue()).toSeq: _*)) - } - - /** - * Create an input stream that pulls messages form a Kafka Broker. - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. 
Defaults to memory-only */ def kafkaStream[T]( zkQuorum: String, groupId: String, topics: JMap[String, JInt], - initialOffsets: JMap[KafkaPartitionKey, JLong], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = @@ -125,7 +98,6 @@ class JavaStreamingContext(val ssc: StreamingContext) { zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), - Map(initialOffsets.mapValues(_.longValue()).toSeq: _*), storageLevel) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index d674b6ee87..c6da1a7f70 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -19,17 +19,12 @@ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ -// Key for a specific Kafka Partition: (broker, topic, group, part) -case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) - /** * Input stream that pulls messages from a Kafka Broker. * * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. */ private[streaming] @@ -37,26 +32,25 @@ class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(kafkaParams, topics, initialOffsets, storageLevel) + new KafkaReceiver(kafkaParams, topics, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] class KafkaReceiver(kafkaParams: Map[String, String], - topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + topics: Map[String, Int], storageLevel: StorageLevel) extends NetworkReceiver[Any] { // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka - var consumerConnector : ZookeeperConsumerConnector = null + var consumerConnector : ConsumerConnector = null def onStop() { blockGenerator.stop() @@ -70,7 +64,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) - logInfo("Initial offsets: " + initialOffsets.toString) // Kafka connection properties val props = new Properties() @@ -79,7 +72,7 @@ class KafkaReceiver(kafkaParams: Map[String, String], // Create the connection to the cluster logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + kafkaParams("zk.connect")) // When autooffset.reset is 'smallest', it is our responsibility to try and whack the @@ -88,9 +81,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], 
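      // (Annotation, not part of this patch.) This cleanup path only runs when the caller puts
      // "autooffset.reset" -> "smallest" into kafkaParams; a later patch in this series relaxes
      // the check to any "autooffset.reset" value. A minimal caller sketch -- the quorum, group
      // id and topic name below are made-up examples:
      //   val kafkaParams = Map(
      //     "zk.connect" -> "localhost:2181",
      //     "groupid" -> "example-group",
      //     "autooffset.reset" -> "smallest")
      //   ssc.kafkaStream[String](kafkaParams, Map("example-topic" -> 1), StorageLevel.MEMORY_ONLY_SER_2)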
tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) } - // If specified, set the topic offset - setOffsets(initialOffsets) - // Create Threads for each Topic/Message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) @@ -100,16 +90,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Overwrites the offets in Zookeper. - private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) { - offsets.foreach { case(key, offset) => - val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) - val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, - topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) - } - } - // Handles Kafka Messages private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { -- cgit v1.2.3 From d61978d0abad30a148680c8a63df33e40e469525 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 15 Mar 2013 23:36:52 -0600 Subject: keeping JavaStreamingContext in sync with StreamingContext + adding comments for better clarity --- .../main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 7 +++---- .../src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 6 ++++-- 2 files changed, 7 insertions(+), 6 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 2373f4824a..7a8864614c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -80,6 +80,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -87,16 +88,14 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( - zkQuorum: String, - groupId: String, + kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - zkQuorum, - groupId, + kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index c6da1a7f70..85693808d1 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -100,8 +100,10 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Handles cleanup of consumer group znode. Lifted with love from Kafka's - // ConsumerConsole.scala tryCleanupZookeeper() + // Delete consumer group from zookeeper. This effectivly resets the group so we can consume from the beginning again. 
+ // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' + // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest': + // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { try { val dir = "/consumers/" + groupId -- cgit v1.2.3 From 329ef34c2e04d28c2ad150cf6674d6e86d7511ce Mon Sep 17 00:00:00 2001 From: seanm Date: Tue, 26 Mar 2013 23:56:15 -0600 Subject: fixing autooffset.reset behavior when set to 'largest' --- .../main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 85693808d1..17a5be3420 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -75,9 +75,9 @@ class KafkaReceiver(kafkaParams: Map[String, String], consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + kafkaParams("zk.connect")) - // When autooffset.reset is 'smallest', it is our responsibility to try and whack the + // When autooffset.reset is defined, it is our responsibility to try and whack the // consumer group zk node. - if (kafkaParams.get("autooffset.reset").exists(_ == "smallest")) { + if (kafkaParams.contains("autooffset.reset")) { tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) } @@ -100,9 +100,11 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Delete consumer group from zookeeper. This effectivly resets the group so we can consume from the beginning again. + // It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because + // Kafka 0.7.2 only honors this param when the group is not in zookeeper. + // // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' - // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest': + // ConsoleConsumer. 
See code related to 'autooffset.reset' when it is set to 'smallest'/'largest': // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { try { -- cgit v1.2.3 From df47b40b764e25cbd10ce49d7152e1d33f51a263 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Wed, 20 Feb 2013 11:51:13 +0800 Subject: Shuffle Performance fix: Use netty embeded OIO file server instead of ConnectionManager Shuffle Performance Optimization: do not send 0-byte block requests to reduce network messages change reference from io.Source to scala.io.Source to avoid looking into io.netty package Signed-off-by: shane-huang --- .../main/java/spark/network/netty/FileClient.java | 89 +++++++ .../netty/FileClientChannelInitializer.java | 29 +++ .../spark/network/netty/FileClientHandler.java | 38 +++ .../main/java/spark/network/netty/FileServer.java | 59 +++++ .../netty/FileServerChannelInitializer.java | 33 +++ .../spark/network/netty/FileServerHandler.java | 68 ++++++ .../java/spark/network/netty/PathResolver.java | 12 + .../scala/spark/network/netty/FileHeader.scala | 57 +++++ .../scala/spark/network/netty/ShuffleCopier.scala | 88 +++++++ .../scala/spark/network/netty/ShuffleSender.scala | 50 ++++ .../main/scala/spark/storage/BlockManager.scala | 272 +++++++++++++++++---- core/src/main/scala/spark/storage/DiskStore.scala | 51 +++- project/SparkBuild.scala | 3 +- .../scala/spark/streaming/util/RawTextSender.scala | 2 +- 14 files changed, 795 insertions(+), 56 deletions(-) create mode 100644 core/src/main/java/spark/network/netty/FileClient.java create mode 100644 core/src/main/java/spark/network/netty/FileClientChannelInitializer.java create mode 100644 core/src/main/java/spark/network/netty/FileClientHandler.java create mode 100644 core/src/main/java/spark/network/netty/FileServer.java create mode 100644 core/src/main/java/spark/network/netty/FileServerChannelInitializer.java create mode 100644 core/src/main/java/spark/network/netty/FileServerHandler.java create mode 100755 core/src/main/java/spark/network/netty/PathResolver.java create mode 100644 core/src/main/scala/spark/network/netty/FileHeader.scala create mode 100644 core/src/main/scala/spark/network/netty/ShuffleCopier.scala create mode 100644 core/src/main/scala/spark/network/netty/ShuffleSender.scala (limited to 'streaming') diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..d0c5081dd2 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -0,0 +1,89 @@ +package spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + +import java.util.Arrays; + +public class FileClient { + + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + + public FileClient(FileClientHandler handler){ + this.handler = handler; + } + + public void init(){ + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + 
.option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .handler(new FileClientChannelInitializer(handler)); + } + + public static final class ChannelCloseListener implements ChannelFutureListener { + private FileClient fc = null; + public ChannelCloseListener(FileClient fc){ + this.fc = fc; + } + @Override + public void operationComplete(ChannelFuture future) { + if (fc.bootstrap!=null){ + fc.bootstrap.shutdown(); + fc.bootstrap = null; + } + } + } + + public void connect(String host, int port){ + try { + + // Start the connection attempt. + channel = bootstrap.connect(host, port).sync().channel(); + // ChannelFuture cf = channel.closeFuture(); + //cf.addListener(new ChannelCloseListener(this)); + } catch (InterruptedException e) { + close(); + } + } + + public void waitForClose(){ + try { + channel.closeFuture().sync(); + } catch (InterruptedException e){ + e.printStackTrace(); + } + } + + public void sendRequest(String file){ + //assert(file == null); + //assert(channel == null); + channel.write(file+"\r\n"); + } + + public void close(){ + if(channel != null) { + channel.close(); + channel = null; + } + if ( bootstrap!=null) { + bootstrap.shutdown(); + bootstrap = null; + } + } + + +} + + diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java new file mode 100644 index 0000000000..50e5704619 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java @@ -0,0 +1,29 @@ +package spark.network.netty; + +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.util.CharsetUtil; + +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.logging.LogLevel; + +public class FileClientChannelInitializer extends + ChannelInitializer { + + private FileClientHandler fhandler; + + public FileClientChannelInitializer(FileClientHandler handler) { + fhandler = handler; + } + + @Override + public void initChannel(SocketChannel channel) { + // file no more than 2G + channel.pipeline() + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java new file mode 100644 index 0000000000..911c8b32b5 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -0,0 +1,38 @@ +package spark.network.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; +import io.netty.util.CharsetUtil; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Logger; + +public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. 
+ return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()){ + handle(ctx,in,currentHeader); + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..729e45f0a1 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -0,0 +1,59 @@ +package spark.network.netty; + +import java.io.File; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.Channel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; + +/** + * Server that accept the path of a file an echo back its content. + */ +public class FileServer { + + private ServerBootstrap bootstrap = null; + private Channel channel = null; + private PathResolver pResolver; + + public FileServer(PathResolver pResolver){ + this.pResolver = pResolver; + } + + public void run(int port) { + // Configure the server. + bootstrap = new ServerBootstrap(); + try { + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. 
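        // (Annotation, not part of this patch.) In this series the server is started from
        // DiskStore when spark.shuffle.use.netty is "true", and `port` is taken from the
        // spark.shuffle.sender.port property (default 6653); see the DiskStore and
        // NettyBlockFetcherIterator hunks later in this log.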
+ channel = bootstrap.bind(port).sync().channel(); + channel.closeFuture().sync(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } finally{ + bootstrap.shutdown(); + } + } + + public void stop(){ + if (channel!=null){ + channel.close(); + } + if (bootstrap != null){ + bootstrap.shutdown(); + bootstrap = null; + } + } +} + + diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..9d0618ff1c --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,33 @@ +package spark.network.netty; + +import java.io.File; +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.string.StringDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.util.CharsetUtil; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.logging.LogLevel; + +public class FileServerChannelInitializer extends + ChannelInitializer { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder( + 8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + + } +} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..e1083e87a2 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerHandler.java @@ -0,0 +1,68 @@ +package spark.network.netty; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; +import io.netty.handler.stream.ChunkedFile; +import java.io.File; +import java.io.FileInputStream; + +public class FileServerHandler extends + ChannelInboundMessageHandlerAdapter { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0 ) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + // 
TODO Auto-generated catch block + //logger.warning("Exception when sending file : " + //+ file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..5d5eda006e --- /dev/null +++ b/core/src/main/java/spark/network/netty/PathResolver.java @@ -0,0 +1,12 @@ +package spark.network.netty; + +public interface PathResolver { + /** + * Get the absolute path of the file + * + * @param fileId + * @return the absolute path of file + */ + public String getAbsolutePath(String fileId); + +} diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000..aed4254234 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/FileHeader.scala @@ -0,0 +1,57 @@ +package spark.network.netty + +import io.netty.buffer._ + +import spark.Logging + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: String) extends Logging { + + lazy val buffer = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.length) + blockId.foreach((x: Char) => buf.writeByte(x)) + //padding the rest of header + if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = idBuilder.toString() + new FileHeader(length, blockId) + } + + + def main (args:Array[String]){ + + val header = new FileHeader(25,"block_0"); + val buf = header.buffer; + val newheader = FileHeader.create(buf); + System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) + + } +} + diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000..d8d35bfeec --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -0,0 +1,88 @@ +package spark.network.netty + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInboundByteHandlerAdapter +import io.netty.util.CharsetUtil + +import java.util.concurrent.atomic.AtomicInteger +import java.util.logging.Logger +import spark.Logging +import spark.network.ConnectionManagerId +import java.util.concurrent.Executors + +private[spark] class ShuffleCopier extends Logging { + + def getBlock(cmId: ConnectionManagerId, + blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + + val handler = new ShuffleClientHandler(resultCollectCallback) + val fc = new FileClient(handler) + fc.init() + fc.connect(cmId.host, cmId.port) + fc.sendRequest(blockId) + 
fc.waitForClose() + fc.close() + } + + def getBlocks(cmId: ConnectionManagerId, + blocks: Seq[(String, Long)], + resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + + blocks.map { + case(blockId,size) => { + getBlock(cmId,blockId,resultCollectCallback) + } + } + } +} + +private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit ) extends FileClientHandler with Logging { + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } +} + +private[spark] object ShuffleCopier extends Logging { + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = { + logInfo("File: " + blockId + " content is : \" " + + content.toString(CharsetUtil.UTF_8) + "\"") + } + + def runGetBlock(host:String, port:Int, file:String){ + val handler = new ShuffleClientHandler(echoResultCollectCallBack) + val fc = new FileClient(handler) + fc.init(); + fc.connect(host, port) + fc.sendRequest(file) + fc.waitForClose(); + fc.close() + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val file = args(2) + val threads = if (args.length>3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + for (i <- Range(0,threads)){ + val runnable = new Runnable() { + def run() { + runGetBlock(host,port,file) + } + } + copiers.execute(runnable) + } + copiers.shutdown + } + +} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000..c1986812e9 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,50 @@ +package spark.network.netty + +import spark.Logging +import java.io.File + + +private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) extends Logging { + val server = new FileServer(pResolver) + + Runtime.getRuntime().addShutdownHook( + new Thread() { + override def run() { + server.stop() + } + } + ) + + def start() { + server.run(port) + } +} + +private[spark] object ShuffleSender { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleSender ") + System.exit(1) + } + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2) map {new File(_)} + val pResovler = new PathResolver { + def getAbsolutePath(blockId:String):String = { + if (!blockId.startsWith("shuffle_")) { + throw new Exception("Block " + blockId + " is not a shuffle block") + } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId) + return file.getAbsolutePath + } + } + val sender = new ShuffleSender(port, pResovler) + + sender.start() + } +} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e972..b8b68d4283 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -23,6 +23,8 @@ import 
spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer +import spark.network.netty.ShuffleCopier +import io.netty.buffer.ByteBuf private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) @@ -467,6 +469,21 @@ class BlockManager( getLocal(blockId).orElse(getRemote(blockId)) } + /** + * A request to fetch one or more blocks, complete with their sizes + */ + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + /** + * A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + * the block (since we want all deserializaton to happen in the calling thread); can also + * represent a fetch failure if size == -1. + */ + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } /** * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined @@ -475,7 +492,12 @@ class BlockManager( */ def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) : BlockFetcherIterator = { - return new BlockFetcherIterator(this, blocksByAddress) + + if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){ + return new NettyBlockFetcherIterator(this, blocksByAddress) + } else { + return new BlockFetcherIterator(this, blocksByAddress) + } } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -908,7 +930,7 @@ class BlockFetcherIterator( if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } - val totalBlocks = blocksByAddress.map(_._2.size).sum + var totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis val localBlockIds = new ArrayBuffer[String]() @@ -974,68 +996,83 @@ class BlockFetcherIterator( } } - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - curBlocks += ((blockId, size)) - curRequestSize += size - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest + def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. 
+ val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + curBlocks += ((blockId, size)) + curRequestSize += size + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] } } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } } + remoteRequests } - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) + def getLocalBlocks(){ + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlockIds) { + getLocal(id) match { + case Some(iter) => { + results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } } - val numGets = remoteBlockIds.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - startTime = System.currentTimeMillis - for (id <- localBlockIds) { - getLocal(id) match { - case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } + def initialize(){ + // Split local and remote blocks. 
+ val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) } + + val numGets = remoteBlockIds.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + initialize() //an iterator that will read fetched blocks off the queue as they arrive. var resultsGotten = 0 @@ -1066,3 +1103,132 @@ class BlockFetcherIterator( def remoteBytesRead = _remoteBytesRead } + +class NettyBlockFetcherIterator( + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] +) extends BlockFetcherIterator(blockManager,blocksByAddress) { + + import blockManager._ + + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, + results : LinkedBlockingQueue[FetchResult]){ + results.put(new FetchResult( + blockId, blockSize, () => dataDeserialize(blockId, blockData) )) + } + + def startCopiers (numCopiers: Int): List [ _ <: Thread]= { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } + } + } + copier.start + copier + }).toList + } + + //keep this to interrupt the threads when necessary + def stopCopiers(copiers : List[_ <: Thread]) { + for (copier <- copiers) { + copier.interrupt() + } + } + + override def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) + val cmId = new ConnectionManagerId(req.address.ip, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cpier = new ShuffleCopier + cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.ip ) + } + + override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val originalTotalBlocks = totalBlocks; + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. 
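          // (Annotation, illustrative numbers only.) If maxBytesInFlight were 48 MB, then
          // minRequestSize below would be 9.6 MB, so each remote node's non-empty blocks are
          // grouped into FetchRequests of roughly that size and about five such requests can
          // be in flight at once instead of one large blocking fetch per node.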
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + if (size > 0) { + curBlocks += ((blockId, size)) + curRequestSize += size + } else if (size == 0){ + //here we changes the totalBlocks + totalBlocks -= 1 + } else { + throw new SparkException("Negative block size "+blockId) + } + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") + remoteRequests + } + + var copiers : List[_ <: Thread] = null + + override def initialize(){ + // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // if all the results has been retrieved + // shutdown the copiers + if (resultsGotten == totalBlocks) { + if( copiers != null ) + stopCopiers(copiers) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + } + diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..d702bb23e0 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -13,24 +13,35 @@ import scala.collection.mutable.ArrayBuffer import spark.executor.ExecutorExitCode import spark.Utils +import spark.Logging +import spark.network.netty.ShuffleSender +import spark.network.netty.PathResolver /** * Stores BlockManager blocks on disk. */ private class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) { + extends BlockStore(blockManager) with Logging { val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + var shuffleSender : Thread = null + val thisInstance = this // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. 
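  // (Annotation, not part of this patch; assumes DiskStore.getFile, not shown in this hunk,
  // hashes block ids the same way as the PathResolver in ShuffleSender above.) A block id then
  // maps to: localDirs(math.abs(id.hashCode) % localDirs.length) /
  //   "%02x".format((math.abs(id.hashCode) / localDirs.length) % subDirsPerLocalDir) / id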
val localDirs = createLocalDirs() val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + addShutdownHook() + if(useNetty){ + startShuffleBlockSender() + } + override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -180,10 +191,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) logDebug("Shutdown hook called") try { localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + if (useNetty && shuffleSender != null) + shuffleSender.stop } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } } }) } + + private def startShuffleBlockSender (){ + try { + val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt + + val pResolver = new PathResolver { + def getAbsolutePath(blockId:String):String = { + if (!blockId.startsWith("shuffle_")) { + return null + } + thisInstance.getFile(blockId).getAbsolutePath() + } + } + shuffleSender = new Thread { + override def run() = { + val sender = new ShuffleSender(port,pResolver) + logInfo("created ShuffleSender binding to port : "+ port) + sender.start + } + } + shuffleSender.setDaemon(true) + shuffleSender.start + + } catch { + case interrupted: InterruptedException => + logInfo("Runner thread for ShuffleBlockSender interrupted") + + case e: Exception => { + logError("Error running ShuffleBlockSender ", e) + if (shuffleSender != null) { + shuffleSender.stop + shuffleSender = null + } + } + } + } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f378b2398..e3645653ee 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -141,7 +141,8 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", - "org.apache.mesos" % "mesos" % "0.9.0-incubating" + "org.apache.mesos" % "mesos" % "0.9.0-incubating", + "io.netty" % "netty-all" % "4.0.0.Beta2" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index d8b987ec86..bd0b0e74c1 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -5,7 +5,7 @@ import spark.util.{RateLimitedOutputStream, IntParam} import java.net.ServerSocket import spark.{Logging, KryoSerializer} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import io.Source +import scala.io.Source import java.io.IOException /** -- cgit v1.2.3 From 6798a09df84fb97e196c84d55cf3e21ad676871f Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 7 Apr 2013 17:47:38 +0530 Subject: Add support for building against hadoop2-yarn : adding new maven profile for it --- bagel/pom.xml | 37 +++++++++++ core/pom.xml | 62 +++++++++++++++++++ .../apache/hadoop/mapred/HadoopMapRedUtil.scala | 3 + .../hadoop/mapreduce/HadoopMapReduceUtil.scala | 3 + .../apache/hadoop/mapred/HadoopMapRedUtil.scala | 13 ++++ .../hadoop/mapreduce/HadoopMapReduceUtil.scala | 13 ++++ .../apache/hadoop/mapred/HadoopMapRedUtil.scala | 3 + 
.../hadoop/mapreduce/HadoopMapReduceUtil.scala | 3 + core/src/main/scala/spark/PairRDDFunctions.scala | 5 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 +- examples/pom.xml | 43 +++++++++++++ pom.xml | 54 ++++++++++++++++ project/SparkBuild.scala | 34 +++++++++-- repl-bin/pom.xml | 50 +++++++++++++++ repl/pom.xml | 71 ++++++++++++++++++++++ streaming/pom.xml | 37 +++++++++++ 16 files changed, 424 insertions(+), 9 deletions(-) create mode 100644 core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala create mode 100644 core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala (limited to 'streaming') diff --git a/bagel/pom.xml b/bagel/pom.xml index 510cff4669..89282161ea 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -102,5 +102,42 @@ + + hadoop2-yarn + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/core/pom.xml b/core/pom.xml index fe9c803728..9baa447662 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -279,5 +279,67 @@ + + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-source + generate-sources + + add-source + + + + src/main/scala + src/hadoop2-yarn/scala + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index ca9f7219de..f286f2cf9c 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index de7b0f81e3..264d421d14 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -6,4 +6,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala 
b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala new file mode 100644 index 0000000000..875c0a220b --- /dev/null +++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -0,0 +1,13 @@ + +package org.apache.hadoop.mapred + +import org.apache.hadoop.mapreduce.TaskType + +trait HadoopMapRedUtil { + def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = + new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId) +} diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala new file mode 100644 index 0000000000..8bc6fb6dea --- /dev/null +++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -0,0 +1,13 @@ +package org.apache.hadoop.mapreduce + +import org.apache.hadoop.conf.Configuration +import task.{TaskAttemptContextImpl, JobContextImpl} + +trait HadoopMapReduceUtil { + def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = + new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId) +} diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index 35300cea58..a0652d7fc7 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index 7afdbff320..7fdbe322fd 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -7,4 +7,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07efba9e8d..39469fa3c8 100644 --- 
a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -545,8 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = new TaskAttemptID(jobtrackerID, - stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -565,7 +564,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * however we're only going to use this local OutputCommitter for * setupJob/commitJob, so we just use a dummy "map" task. */ - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0) + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bdd974590a..901d01ef30 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -57,7 +57,7 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] val conf = confBroadcast.value.value - val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance if (format.isInstanceOf[Configurable]) { diff --git a/examples/pom.xml b/examples/pom.xml index 39cc47c709..9594257ad4 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -118,5 +118,48 @@ + + hadoop2-yarn + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.spark-project + spark-streaming + ${project.version} + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/pom.xml b/pom.xml index 08d1fc12e0..b3134a957d 100644 --- a/pom.xml +++ b/pom.xml @@ -558,5 +558,59 @@ + + + hadoop2-yarn + + 2 + 2.0.3-alpha + + + + + maven-root + Maven root repository + http://repo1.maven.org/maven2/ + + true + + + false + + + + + + + + + org.apache.hadoop + hadoop-client + ${yarn.version} + + + org.apache.hadoop + hadoop-yarn-api + ${yarn.version} + + + org.apache.hadoop + hadoop-yarn-common + ${yarn.version} + + + + org.apache.avro + avro + 1.7.1.cloudera.2 + + + org.apache.avro + avro-ipc + 1.7.1.cloudera.2 + + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f378b2398..f041930b4e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1,3 +1,4 @@ + import sbt._ import sbt.Classpaths.publishTask import Keys._ @@ -10,12 +11,18 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. 
For example, "0.20.2", "0.20.205.0", or // "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - val HADOOP_VERSION = "1.0.4" - val HADOOP_MAJOR_VERSION = "1" + //val HADOOP_VERSION = "1.0.4" + //val HADOOP_MAJOR_VERSION = "1" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" //val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1" //val HADOOP_MAJOR_VERSION = "2" + //val HADOOP_YARN = false + + // For Hadoop 2 YARN support + val HADOOP_VERSION = "2.0.3-alpha" + val HADOOP_MAJOR_VERSION = "2" + val HADOOP_YARN = true lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming) @@ -129,7 +136,6 @@ object SparkBuild extends Build { "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "com.ning" % "compress-lzf" % "0.8.4", - "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION, "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", @@ -142,8 +148,26 @@ object SparkBuild extends Build { "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", "org.apache.mesos" % "mesos" % "0.9.0-incubating" - ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, - unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } + ) ++ ( + if (HADOOP_MAJOR_VERSION == "2") { + if (HADOOP_YARN) { + Seq( + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION + ) + } else { + Seq( + "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION + ) + } + } else { + Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION) + }), + unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / + ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") "src/hadoop2-yarn/scala" else "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala" ) + } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings def rootSettings = sharedSettings ++ Seq( diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index dd720e2291..f9d84fd3c4 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -153,6 +153,56 @@ + + hadoop2-yarn + + hadoop2-yarn + + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.spark-project + spark-bagel + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-examples + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-repl + ${project.version} + hadoop2-yarn + runtime + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + deb diff --git a/repl/pom.xml b/repl/pom.xml index a3e4606edc..1f885673f4 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -187,5 +187,76 @@ + + hadoop2-yarn + + hadoop2-yarn + + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.spark-project + spark-bagel + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-examples + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-streaming + ${project.version} + hadoop2-yarn + runtime + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + 
org.apache.hadoop + hadoop-yarn-common + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/streaming/pom.xml b/streaming/pom.xml index ec077e8089..fc2e211a42 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -149,5 +149,42 @@ + + hadoop2-yarn + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + -- cgit v1.2.3 From b42d68c8ce9f63513969297b65f4b5a2b06e6078 Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 15 Apr 2013 12:54:55 -0600 Subject: fixing Spark Streaming count() so that 0 will be emitted when there is nothing to count --- streaming/src/main/scala/spark/streaming/DStream.scala | 5 ++++- streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e1be5ef51c..e3a9247924 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,7 +441,10 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) + def count(): DStream[Long] = { + val zero = new ConstantInputDStream(context, context.sparkContext.makeRDD(Seq((null, 0L)), 1)) + this.map(_ => (null, 1L)).union(zero).reduceByKey(_ + _).map(_._2) + } /** * Return a new DStream in which each RDD contains the counts of each distinct value in diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index 8fce91853c..168e1b7a55 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -90,9 +90,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("count") { testOperation( - Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4), + Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4), (s: DStream[Int]) => s.count(), - Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L)) + Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L)) ) } -- cgit v1.2.3 From ab0f834dbb509d323577572691293b74368a9d86 Mon Sep 17 00:00:00 2001 From: seanm Date: Tue, 16 Apr 2013 11:57:05 -0600 Subject: adding spark.streaming.blockInterval property --- docs/configuration.md | 7 +++++++ .../main/scala/spark/streaming/dstream/NetworkInputDStream.scala | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) (limited to 'streaming') diff --git a/docs/configuration.md b/docs/configuration.md index 04eb6daaa5..55f1962b18 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -253,6 +253,13 @@ Apart from these, the following properties are also available, and may be useful applications). Note that any RDD that persists in memory for more than this duration will be cleared as well. + + spark.streaming.blockInterval + 200 + + Duration (milliseconds) of how long to batch new objects coming from network receivers. 
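As the NetworkInputDStream change below shows, this setting is an ordinary system property read by the receiver, so it can be overridden before the streaming job starts. A one-line sketch; the 100 ms value is only an example, and 200 ms remains the default:

    // Illustrative: batch incoming objects into 100 ms blocks instead of the default 200 ms.
    System.setProperty("spark.streaming.blockInterval", "100")
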
+ + diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 7385474963..26805e9621 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -198,7 +198,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log case class Block(id: String, iterator: Iterator[T], metadata: Any = null) val clock = new SystemClock() - val blockInterval = 200L + val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) -- cgit v1.2.3 From 7e56e99573b4cf161293e648aeb159375c9c0fcb Mon Sep 17 00:00:00 2001 From: seanm Date: Sun, 24 Mar 2013 13:40:19 -0600 Subject: Surfacing decoders on KafkaInputDStream --- .../spark/streaming/examples/KafkaWordCount.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 11 ++++---- .../streaming/api/java/JavaStreamingContext.scala | 33 ++++++++++++++++------ .../streaming/dstream/KafkaInputDStream.scala | 17 ++++++----- .../test/java/spark/streaming/JavaAPISuite.java | 6 ++-- 5 files changed, 44 insertions(+), 25 deletions(-) (limited to 'streaming') diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 9b135a5c54..e0c3555f21 100644 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -37,7 +37,7 @@ object KafkaWordCount { ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap) + val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) wordCounts.print() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index bb7f216ca7..2c6326943d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import java.util.UUID import twitter4j.Status + /** * A StreamingContext is the main entry point for Spark Streaming functionality. 
Besides the basic * information (such as, cluster URL and job name) to internally create a SparkContext, it provides @@ -207,14 +208,14 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ - def kafkaStream[T: ClassManifest]( + def kafkaStream( zkQuorum: String, groupId: String, topics: Map[String, Int], storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 - ): DStream[T] = { + ): DStream[String] = { val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); - kafkaStream[T](kafkaParams, topics, storageLevel) + kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) } /** @@ -224,12 +225,12 @@ class StreamingContext private ( * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ - def kafkaStream[T: ClassManifest]( + def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel) + val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 7a8864614c..13427873ff 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -68,33 +68,50 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. */ - def kafkaStream[T]( + def kafkaStream( zkQuorum: String, groupId: String, topics: JMap[String, JInt]) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2) } /** * Create an input stream that pulls messages form a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * @param storageLevel RDD storage level. Defaults to memory-only + * in its own thread. + */ + def kafkaStream( + zkQuorum: String, + groupId: String, + topics: JMap[String, JInt], + storageLevel: StorageLevel) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + } + + /** + * Create an input stream that pulls messages form a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. 
See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only */ - def kafkaStream[T]( + def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T]( + ssc.kafkaStream[T, D]( kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 17a5be3420..7bd53fb6dd 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -9,7 +9,7 @@ import java.util.concurrent.Executors import kafka.consumer._ import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.StringDecoder +import kafka.serializer.Decoder import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ import kafka.utils.ZKStringSerializer @@ -28,7 +28,7 @@ import scala.collection.JavaConversions._ * @param storageLevel RDD storage level. */ private[streaming] -class KafkaInputDStream[T: ClassManifest]( +class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -37,15 +37,17 @@ class KafkaInputDStream[T: ClassManifest]( def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(kafkaParams, topics, storageLevel) + new KafkaReceiver[T, D](kafkaParams, topics, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(kafkaParams: Map[String, String], +class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest]( + kafkaParams: Map[String, String], topics: Map[String, Int], - storageLevel: StorageLevel) extends NetworkReceiver[Any] { + storageLevel: StorageLevel + ) extends NetworkReceiver[Any] { // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) @@ -82,7 +84,8 @@ class KafkaReceiver(kafkaParams: Map[String, String], } // Create Threads for each Topic/Message Stream we are listening - val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]] + val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder) // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -91,7 +94,7 @@ class KafkaReceiver(kafkaParams: Map[String, String], } // Handles Kafka Messages - private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable { def run() { logInfo("Starting MessageHandler.") for (msgAndMetadata <- stream) { diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 3bed500f73..61e4c0a207 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -23,7 +23,6 @@ import 
spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; import spark.streaming.JavaCheckpointTestUtils; -import spark.streaming.dstream.KafkaPartitionKey; import spark.streaming.InputStreamsSuite; import java.io.*; @@ -1203,10 +1202,9 @@ public class JavaAPISuite implements Serializable { @Test public void testKafkaStream() { HashMap topics = Maps.newHashMap(); - HashMap offsets = Maps.newHashMap(); JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets, + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics); + JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); } -- cgit v1.2.3 From afee9024430ef79cc0840a5e5788b60c8c53f9d2 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 28 Apr 2013 22:26:45 +0530 Subject: Attempt to fix streaming test failures after yarn branch merge --- bagel/src/test/scala/bagel/BagelSuite.scala | 1 + core/src/test/scala/spark/LocalSparkContext.scala | 3 ++- repl/src/test/scala/spark/repl/ReplSuite.scala | 1 + .../main/scala/spark/streaming/Checkpoint.scala | 30 +++++++++++++++++----- .../spark/streaming/util/MasterFailureTest.scala | 8 +++++- .../spark/streaming/BasicOperationsSuite.scala | 1 + .../scala/spark/streaming/CheckpointSuite.scala | 4 ++- .../test/scala/spark/streaming/FailureSuite.scala | 2 ++ .../scala/spark/streaming/InputStreamsSuite.scala | 1 + .../spark/streaming/WindowOperationsSuite.scala | 1 + 10 files changed, 42 insertions(+), 10 deletions(-) (limited to 'streaming') diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 25db395c22..a09c978068 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -23,6 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } test("halting by voting") { diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala index ff00dd05dd..76d5258b02 100644 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -27,6 +27,7 @@ object LocalSparkContext { sc.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. 
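The System.clearProperty calls added across these test suites all address the same problem: each test builds a fresh SparkContext or StreamingContext, and Akka does not unbind its port immediately on shutdown, so the stale spark.driver.port and spark.hostPort values must not leak into the next context. A minimal sketch of the teardown pattern these diffs converge on, assuming a test that owns a context named sc:

    // Illustrative teardown between tests.
    sc.stop()
    System.clearProperty("spark.driver.port")   // avoid Akka rebinding to the previous port
    System.clearProperty("spark.hostPort")      // companion property added in this patch series
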
*/ @@ -38,4 +39,4 @@ object LocalSparkContext { } } -} \ No newline at end of file +} diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index 43559b96d3..1c64f9b98d 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -32,6 +32,7 @@ class ReplSuite extends FunSuite { interp.sparkContext.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") return out.toString } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index e303e33e5e..7bd104b8d5 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -38,28 +38,43 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) private[streaming] class CheckpointWriter(checkpointDir: String) extends Logging { val file = new Path(checkpointDir, "graph") + // The file to which we actually write - and then "move" to file. + private val writeFile = new Path(file.getParent, file.getName + ".next") + private val bakFile = new Path(file.getParent, file.getName + ".bk") + + @volatile private var stopped = false + val conf = new Configuration() var fs = file.getFileSystem(conf) val maxAttempts = 3 val executor = Executors.newFixedThreadPool(1) + // Removed code which validates whether there is only one CheckpointWriter per path 'file' since + // I did not notice any errors - reintroduce it ? + class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() while (attempts < maxAttempts) { + if (stopped) { + logInfo("Already stopped, ignore checkpoint attempt for " + file) + return + } attempts += 1 try { logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) - } - val fos = fs.create(file) + // This is inherently thread unsafe .. so alleviating it by writing to '.new' and then doing moves : which should be pretty fast. 
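The hunk continues below with the write itself; the shape is the usual write-to-a-scratch-file-then-rename idiom, so the new checkpoint is written in full to graph.next before it replaces graph, and the previous checkpoint survives as graph.bk. A hedged sketch of how a restore path could take advantage of that, reusing the fs, file and bakFile names from this class; the helper itself is hypothetical and not part of the patch:

    // Hypothetical helper: checkpoint files to try when restoring, newest first.
    import org.apache.hadoop.fs.{FileSystem, Path}
    def checkpointCandidates(fs: FileSystem, file: Path, bakFile: Path): Seq[Path] =
      Seq(file, bakFile).filter(p => fs.exists(p))
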
+ val fos = fs.create(writeFile) fos.write(bytes) fos.close() - fos.close() + if (fs.exists(file) && fs.rename(file, bakFile)) { + logDebug("Moved existing checkpoint file to " + bakFile) + } + // paranoia + fs.delete(file, false) + fs.rename(writeFile, file) + val finishTime = System.currentTimeMillis(); logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") @@ -84,6 +99,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } def stop() { + stopped = true executor.shutdown() } } diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index f673e5be15..e7a3f92bc0 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -74,6 +74,7 @@ object MasterFailureTest extends Logging { val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Long], state: Option[Long]) => { + logInfo("UpdateFunc .. state = " + state.getOrElse(0L) + ", values = " + values) Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) } st.flatMap(_.split(" ")) @@ -159,6 +160,7 @@ object MasterFailureTest extends Logging { // Setup the streaming computation with the given operation System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) ssc.checkpoint(checkpointDir.toString) val inputStream = ssc.textFileStream(testDir.toString) @@ -205,6 +207,7 @@ object MasterFailureTest extends Logging { // (iii) Its not timed out yet System.clearProperty("spark.streaming.clock") System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") ssc.start() val startTime = System.currentTimeMillis() while (!killed && !isLastOutputGenerated && !isTimedOut) { @@ -357,13 +360,16 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) // Write the data to a local file and then move it to the target test directory val localFile = new File(localTestDir, (i+1).toString) val hadoopFile = new Path(testDir, (i+1).toString) + val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString) FileUtils.writeStringToFile(localFile, input(i).toString + "\n") var tries = 0 var done = false while (!done && tries < maxTries) { tries += 1 try { - fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile) + fs.rename(tempHadoopFile, hadoopFile) done = true } catch { case ioe: IOException => { diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index cf2ed8b1d4..e7352deb81 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -15,6 +15,7 @@ class BasicOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } test("map") { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala 
b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index cac86deeaf..607dea77ec 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -31,6 +31,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } var ssc: StreamingContext = null @@ -325,6 +326,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { ) ssc = new StreamingContext(checkpointDir) System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") ssc.start() val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) // the first element will be re-processed data of the last batch before restart @@ -350,4 +352,4 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] outputStream.output } -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index a5fa7ab92d..4529e774e9 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -22,10 +22,12 @@ class FailureSuite extends FunSuite with BeforeAndAfter with Logging { val batchDuration = Milliseconds(1000) before { + logInfo("BEFORE ...") FileUtils.deleteDirectory(new File(directory)) } after { + logInfo("AFTER ...") FileUtils.deleteDirectory(new File(directory)) } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 67dca2ac31..0acb6db6f2 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -41,6 +41,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 1b66f3bda2..80d827706f 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -16,6 +16,7 @@ class WindowOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } val largerSlideInput = Seq( -- cgit v1.2.3 From 7fa6978a1e8822cf377fbb1e8a8d23adc4ebe12e Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 28 Apr 2013 23:08:10 +0530 Subject: Allow CheckpointWriter pending tasks to finish --- streaming/src/main/scala/spark/streaming/Checkpoint.scala | 13 +++++++------ streaming/src/main/scala/spark/streaming/DStreamGraph.scala | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 7bd104b8d5..4bbad908d0 100644 --- 
a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -42,7 +42,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { private val writeFile = new Path(file.getParent, file.getName + ".next") private val bakFile = new Path(file.getParent, file.getName + ".bk") - @volatile private var stopped = false + private var stopped = false val conf = new Configuration() var fs = file.getFileSystem(conf) @@ -57,10 +57,6 @@ class CheckpointWriter(checkpointDir: String) extends Logging { var attempts = 0 val startTime = System.currentTimeMillis() while (attempts < maxAttempts) { - if (stopped) { - logInfo("Already stopped, ignore checkpoint attempt for " + file) - return - } attempts += 1 try { logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") @@ -99,8 +95,13 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } def stop() { - stopped = true + synchronized { + if (stopped) return ; + stopped = true + } executor.shutdown() + val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) + logInfo("CheckpointWriter executor terminated ? " + terminated) } } diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index adb7f3a24d..3b331956f5 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -54,8 +54,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { throw new Exception("Batch duration already set as " + batchDuration + ". cannot set it again.") } + batchDuration = duration } - batchDuration = duration } def remember(duration: Duration) { -- cgit v1.2.3 From 3a89a76b874298853cf47510ab33e863abf117d7 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 29 Apr 2013 00:04:12 +0530 Subject: Make log message more descriptive to aid in debugging --- streaming/src/main/scala/spark/streaming/Checkpoint.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 4bbad908d0..66e67cbfa1 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -100,8 +100,10 @@ class CheckpointWriter(checkpointDir: String) extends Logging { stopped = true } executor.shutdown() + val startTime = System.currentTimeMillis() val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) - logInfo("CheckpointWriter executor terminated ? " + terminated) + val endTime = System.currentTimeMillis() + logInfo("CheckpointWriter executor terminated ? 
" + terminated + ", waited for " + (endTime - startTime) + " ms.") } } -- cgit v1.2.3 From 430c531464a5372237c97394f8f4b6ec344394c0 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 29 Apr 2013 00:24:30 +0530 Subject: Remove debug statements --- streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala | 1 - streaming/src/test/scala/spark/streaming/FailureSuite.scala | 2 -- 2 files changed, 3 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index e7a3f92bc0..426a9b6f71 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -74,7 +74,6 @@ object MasterFailureTest extends Logging { val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Long], state: Option[Long]) => { - logInfo("UpdateFunc .. state = " + state.getOrElse(0L) + ", values = " + values) Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) } st.flatMap(_.split(" ")) diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 4529e774e9..a5fa7ab92d 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -22,12 +22,10 @@ class FailureSuite extends FunSuite with BeforeAndAfter with Logging { val batchDuration = Milliseconds(1000) before { - logInfo("BEFORE ...") FileUtils.deleteDirectory(new File(directory)) } after { - logInfo("AFTER ...") FileUtils.deleteDirectory(new File(directory)) } -- cgit v1.2.3 From d761e7359deb7ca864d33b8f2e4380b57448630b Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 12:05:10 -0600 Subject: adding kafkaStream API tests --- streaming/src/test/java/spark/streaming/JavaAPISuite.java | 4 ++-- .../src/test/scala/spark/streaming/InputStreamsSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'streaming') diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 61e4c0a207..350d0888a3 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -4,6 +4,7 @@ import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.io.Files; +import kafka.serializer.StringDecoder; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.junit.After; import org.junit.Assert; @@ -1203,8 +1204,7 @@ public class JavaAPISuite implements Serializable { public void testKafkaStream() { HashMap topics = Maps.newHashMap(); JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 1024d3ac97..595c766a21 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -240,6 
+240,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i) === expectedOutput(i)) } } + + test("kafka input stream") { + val ssc = new StreamingContext(master, framework, batchDuration) + val topics = Map("my-topic" -> 1) + val test1 = ssc.kafkaStream("localhost:12345", "group", topics) + val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK) + + // Test specifying decoder + val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group") + val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) + } } -- cgit v1.2.3 From b95c1bdbbaeea86152e24b394a03bbbad95989d5 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 12:47:24 -0600 Subject: count() now uses a transform instead of ConstantInputDStream --- streaming/src/main/scala/spark/streaming/DStream.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e3a9247924..e125310861 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,10 +441,7 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = { - val zero = new ConstantInputDStream(context, context.sparkContext.makeRDD(Seq((null, 0L)), 1)) - this.map(_ => (null, 1L)).union(zero).reduceByKey(_ + _).map(_._2) - } + def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2) /** * Return a new DStream in which each RDD contains the counts of each distinct value in -- cgit v1.2.3 From 3632980b1b61dbb9ab9a3ab3d92fb415cb7173b9 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 15:54:26 -0600 Subject: fixing indentation --- .../src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 13427873ff..4ad2bdf8a8 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -105,7 +105,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. 
Defaults to memory-only */ def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( - kafkaParams: JMap[String, String], + kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { -- cgit v1.2.3 From f25282def5826fab6caabff28c82c57a7f3fdcb8 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 17:34:28 -0600 Subject: fixing kafkaStream Java API and adding test --- .../scala/spark/streaming/api/java/JavaStreamingContext.scala | 10 +++++++--- streaming/src/test/java/spark/streaming/JavaAPISuite.java | 6 ++++++ 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 4ad2bdf8a8..b35d9032f1 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -99,18 +99,22 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. + * @param typeClass Type of RDD + * @param decoderClass Type of kafka decoder * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only */ - def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( + def kafkaStream[T, D <: kafka.serializer.Decoder[_]]( + typeClass: Class[T], + decoderClass: Class[D], kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] ssc.kafkaStream[T, D]( kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 350d0888a3..e5fdbe1b7a 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -1206,6 +1206,12 @@ public class JavaAPISuite implements Serializable { JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); + + HashMap kafkaParams = Maps.newHashMap(); + kafkaParams.put("zk.connect","localhost:12345"); + kafkaParams.put("groupid","consumer-group"); + JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics, + StorageLevel.MEMORY_AND_DISK()); } @Test -- cgit v1.2.3 From e7982c798efccd523165d0e347c7912ba14fcdd7 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Sat, 18 May 2013 16:11:29 -0700 Subject: Exclude old versions of Netty from Maven-based build --- pom.xml | 6 ++++++ streaming/pom.xml | 6 ++++++ 2 files changed, 12 insertions(+) (limited to 'streaming') diff --git a/pom.xml b/pom.xml index eda18fdd12..6ee64d07c2 100644 --- a/pom.xml +++ b/pom.xml @@ -565,6 +565,12 @@ org.apache.avro avro-ipc 1.7.1.cloudera.2 + + + org.jboss.netty + netty + + diff 
--git a/streaming/pom.xml b/streaming/pom.xml index 08ff3e2ae1..4dc9a19d51 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -41,6 +41,12 @@ org.apache.flume flume-ng-sdk 1.2.0 + + + org.jboss.netty + netty + + com.github.sgroschupf -- cgit v1.2.3 From 93a1643405d7c1a1fffe8210130341f34d64ea72 Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Fri, 21 Jun 2013 14:21:52 +0100 Subject: Allow other twitter authorizations than username/password --- .../src/main/scala/spark/streaming/StreamingContext.scala | 15 ++++++++++++++- .../spark/streaming/dstream/TwitterInputDStream.scala | 14 ++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b8b60aab43..f97e47ada0 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import twitter4j.Status +import twitter4j.auth.{Authorization, BasicAuthorization} /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -372,8 +373,20 @@ class StreamingContext private ( password: String, filters: Seq[String] = Nil, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[Status] = twitterStream(new BasicAuthorization(username, password), filters, storageLevel) + + /** + * Create a input stream that returns tweets received from Twitter. + * @param twitterAuth Twitter4J authentication + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + twitterAuth: Authorization, + filters: Seq[String] = Nil, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[Status] = { - val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel) + val inputStream = new TwitterInputDStream(this, twitterAuth, filters, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index c697498862..0b01091a52 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -6,6 +6,7 @@ import storage.StorageLevel import twitter4j._ import twitter4j.auth.BasicAuthorization +import twitter4j.auth.Authorization /* A stream of Twitter statuses, potentially filtered by one or more keywords. 
* @@ -16,21 +17,19 @@ import twitter4j.auth.BasicAuthorization private[streaming] class TwitterInputDStream( @transient ssc_ : StreamingContext, - username: String, - password: String, + twitterAuth: Authorization, filters: Seq[String], storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { - + override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(username, password, filters, storageLevel) + new TwitterReceiver(twitterAuth, filters, storageLevel) } } private[streaming] class TwitterReceiver( - username: String, - password: String, + twitterAuth: Authorization, filters: Seq[String], storageLevel: StorageLevel ) extends NetworkReceiver[Status] { @@ -40,8 +39,7 @@ class TwitterReceiver( protected override def onStart() { blockGenerator.start() - twitterStream = new TwitterStreamFactory() - .getInstance(new BasicAuthorization(username, password)) + twitterStream = new TwitterStreamFactory().getInstance(twitterAuth) twitterStream.addListener(new StatusListener { def onStatus(status: Status) = { blockGenerator += status -- cgit v1.2.3 From 8955787a596216a35ad4ec52b57331aa40444bef Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Mon, 24 Jun 2013 09:15:17 +0100 Subject: Twitter API v1 is retired - username/password auth no longer possible --- .../main/scala/spark/streaming/StreamingContext.scala | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index f97e47ada0..05be6bd58a 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import twitter4j.Status -import twitter4j.auth.{Authorization, BasicAuthorization} +import twitter4j.auth.Authorization /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -361,20 +361,6 @@ class StreamingContext private ( fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } - /** - * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def twitterStream( - username: String, - password: String, - filters: Seq[String] = Nil, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): DStream[Status] = twitterStream(new BasicAuthorization(username, password), filters, storageLevel) - /** * Create a input stream that returns tweets received from Twitter. 
* @param twitterAuth Twitter4J authentication -- cgit v1.2.3 From 48c7e373c62b2e8cf48157ceb0d92c38c3a40652 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 24 Jun 2013 23:11:04 -0700 Subject: Minor formatting fixes --- .../src/main/scala/spark/streaming/DStream.scala | 9 +++++-- .../scala/spark/streaming/StreamingContext.scala | 29 +++++++++++++--------- .../streaming/api/java/JavaStreamingContext.scala | 15 +++++++---- 3 files changed, 34 insertions(+), 19 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e125310861..9be7926a4a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,7 +441,12 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2) + def count(): DStream[Long] = { + this.map(_ => (null, 1L)) + .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) + .reduceByKey(_ + _) + .map(_._2) + } /** * Return a new DStream in which each RDD contains the counts of each distinct value in @@ -457,7 +462,7 @@ abstract class DStream[T: ClassManifest] ( * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: RDD[T] => Unit) { - foreach((r: RDD[T], t: Time) => foreachFunc(r)) + this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) } /** diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 2c6326943d..03d2907323 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -171,10 +171,11 @@ class StreamingContext private ( * should be same. */ def actorStream[T: ClassManifest]( - props: Props, - name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, - supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = { + props: Props, + name: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy + ): DStream[T] = { networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) } @@ -182,9 +183,10 @@ class StreamingContext private ( * Create an input stream that receives messages pushed by a zeromq publisher. * @param publisherUrl Url of remote zeromq publisher * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence - * of byte thus it needs the converter(which might be deserializer of bytes) - * to translate from sequence of sequence of bytes, where sequence refer to a frame + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic + * and each frame has sequence of byte thus it needs the converter + * (which might be deserializer of bytes) to translate from sequence + * of sequence of bytes, where sequence refer to a frame * and sub sequence refer to its payload. * @param storageLevel RDD storage level. Defaults to memory-only. 
*/ @@ -204,7 +206,7 @@ class StreamingContext private ( * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. + * in its own thread. * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ @@ -214,15 +216,17 @@ class StreamingContext private ( topics: Map[String, Int], storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[String] = { - val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); + val kafkaParams = Map[String, String]( + "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000") kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) } /** * Create an input stream that pulls messages from a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. + * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( @@ -395,7 +399,8 @@ class StreamingContext private ( * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. + * Set as null if no RDD should be returned when empty * @tparam T Type of objects in the RDD */ def queueStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index b35d9032f1..fd5e06b50f 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -75,7 +75,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { : JavaDStream[String] = { implicit val cmt: ClassManifest[String] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2) + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + StorageLevel.MEMORY_ONLY_SER_2) } /** @@ -83,8 +84,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only - * in its own thread. 
+ * */ def kafkaStream( zkQuorum: String, @@ -94,14 +96,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { : JavaDStream[String] = { implicit val cmt: ClassManifest[String] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** * Create an input stream that pulls messages form a Kafka Broker. * @param typeClass Type of RDD * @param decoderClass Type of kafka decoder - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only @@ -113,7 +117,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] ssc.kafkaStream[T, D]( kafkaParams.toMap, -- cgit v1.2.3 From 366572edcab87701fd795ca0142ac9829b312d36 Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Tue, 25 Jun 2013 22:59:34 +0100 Subject: Include a default OAuth implementation, and update examples and JavaStreamingContext --- .../streaming/examples/TwitterAlgebirdCMS.scala | 2 +- .../streaming/examples/TwitterAlgebirdHLL.scala | 2 +- .../streaming/examples/TwitterPopularTags.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 2 +- .../streaming/api/java/JavaStreamingContext.scala | 69 +++++++++++++++------- .../streaming/dstream/TwitterInputDStream.scala | 32 +++++++++- 6 files changed, 81 insertions(+), 28 deletions(-) (limited to 'streaming') diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala index a9642100e3..548190309e 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -45,7 +45,7 @@ object TwitterAlgebirdCMS { val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala index f3288bfb85..5a86c6318d 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -34,7 +34,7 @@ object TwitterAlgebirdHLL { val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(username, 
password, filters, StorageLevel.MEMORY_ONLY_SER) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala index 9d4494c6f2..076c3878c8 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala @@ -23,7 +23,7 @@ object TwitterPopularTags { val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(username, password, filters) + val stream = ssc.twitterStream(None, filters) val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 05be6bd58a..0f36504c0d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -368,7 +368,7 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects */ def twitterStream( - twitterAuth: Authorization, + twitterAuth: Option[Authorization] = None, filters: Seq[String] = Nil, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[Status] = { diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 3d149a742c..85390ef57e 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -4,23 +4,18 @@ import spark.streaming._ import receivers.{ActorReceiver, ReceiverSupervisorStrategy} import spark.streaming.dstream._ import spark.storage.StorageLevel - import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import spark.api.java.{JavaSparkContext, JavaRDD} - import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - import twitter4j.Status - import akka.actor.Props import akka.actor.SupervisorStrategy import akka.zeromq.Subscribe - import scala.collection.JavaConversions._ - import java.lang.{Long => JLong, Integer => JInt} import java.io.InputStream import java.util.{Map => JMap} +import twitter4j.auth.Authorization /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -315,46 +310,78 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create a input stream that returns tweets received from Twitter. 
- * @param username Twitter username - * @param password Twitter password + * @param twitterAuth Twitter4J Authorization * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ def twitterStream( - username: String, - password: String, + twitterAuth: Authorization, filters: Array[String], storageLevel: StorageLevel ): JavaDStream[Status] = { - ssc.twitterStream(username, password, filters, storageLevel) + ssc.twitterStream(Some(twitterAuth), filters, storageLevel) + } + + /** + * Create a input stream that returns tweets received from Twitter using + * java.util.Preferences to store OAuth token. OAuth key and secret should + * be provided using system properties twitter4j.oauth.consumerKey and + * twitter4j.oauth.consumerSecret + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + filters: Array[String], + storageLevel: StorageLevel + ): JavaDStream[Status] = { + ssc.twitterStream(None, filters, storageLevel) } /** * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password + * @param twitterAuth Twitter4J Authorization * @param filters Set of filter strings to get only those tweets that match them */ def twitterStream( - username: String, - password: String, + twitterAuth: Authorization, filters: Array[String] ): JavaDStream[Status] = { - ssc.twitterStream(username, password, filters) + ssc.twitterStream(Some(twitterAuth), filters) + } + + /** + * Create a input stream that returns tweets received from Twitter using + * java.util.Preferences to store OAuth token. OAuth key and secret should + * be provided using system properties twitter4j.oauth.consumerKey and + * twitter4j.oauth.consumerSecret + * @param filters Set of filter strings to get only those tweets that match them + */ + def twitterStream( + filters: Array[String] + ): JavaDStream[Status] = { + ssc.twitterStream(None, filters) } /** * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password + * @param twitterAuth Twitter4J Authorization */ def twitterStream( - username: String, - password: String + twitterAuth: Authorization ): JavaDStream[Status] = { - ssc.twitterStream(username, password) + ssc.twitterStream(Some(twitterAuth)) } + /** + * Create a input stream that returns tweets received from Twitter using + * java.util.Preferences to store OAuth token. OAuth key and secret should + * be provided using system properties twitter4j.oauth.consumerKey and + * twitter4j.oauth.consumerSecret + */ + def twitterStream(): JavaDStream[Status] = { + ssc.twitterStream() + } + /** * Create an input stream with any arbitrary user implemented actor receiver. 
* @param props Props object defining creation of the actor diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index 0b01091a52..e0c654d385 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -3,27 +3,53 @@ package spark.streaming.dstream import spark._ import spark.streaming._ import storage.StorageLevel - import twitter4j._ import twitter4j.auth.BasicAuthorization import twitter4j.auth.Authorization +import java.util.prefs.Preferences +import twitter4j.conf.PropertyConfiguration +import twitter4j.auth.OAuthAuthorization +import twitter4j.auth.AccessToken /* A stream of Twitter statuses, potentially filtered by one or more keywords. * * @constructor create a new Twitter stream using the supplied username and password to authenticate. * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. +* +* Includes a simple implementation of OAuth using consumer key and secret provided using system +* properties twitter4j.oauth.consumerKey and twitter4j.oauth.consumerSecret */ private[streaming] class TwitterInputDStream( @transient ssc_ : StreamingContext, - twitterAuth: Authorization, + twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { + lazy val createOAuthAuthorization: Authorization = { + val userRoot = Preferences.userRoot(); + val token = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN, null)) + val tokenSecret = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, null)) + val oAuth = new OAuthAuthorization(new PropertyConfiguration(System.getProperties())) + if (token.isEmpty || tokenSecret.isEmpty) { + val requestToken = oAuth.getOAuthRequestToken() + println("Authorize application using URL: "+requestToken.getAuthorizationURL()) + println("Enter PIN: ") + val pin = Console.readLine + val accessToken = if (pin.length() > 0) oAuth.getOAuthAccessToken(requestToken, pin) else oAuth.getOAuthAccessToken() + userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN, accessToken.getToken()) + userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, accessToken.getTokenSecret()) + userRoot.flush() + } else { + oAuth.setOAuthAccessToken(new AccessToken(token.get, tokenSecret.get)); + } + oAuth + } + override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(twitterAuth, filters, storageLevel) + new TwitterReceiver(if (twitterAuth.isEmpty) createOAuthAuthorization else twitterAuth.get, filters, storageLevel) } } -- cgit v1.2.3 From 4358acfe07e991090fbe009aafe3f5110fbf0c40 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 29 Jun 2013 15:25:06 -0700 Subject: Initialize Twitter4J OAuth from system properties instead of prompting --- .../scala/spark/streaming/StreamingContext.scala | 4 ++- .../streaming/api/java/JavaStreamingContext.scala | 23 +++++++--------- .../streaming/dstream/TwitterInputDStream.scala | 32 ++++++---------------- .../test/java/spark/streaming/JavaAPISuite.java | 2 +- 4 files changed, 23 insertions(+), 38 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 
e61438fe3a..36b841af8f 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -381,7 +381,9 @@ class StreamingContext private ( /** * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J authentication + * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth + * authorization; this uses the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret. * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index c4a223b419..ed7b789d98 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -307,7 +307,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J Authorization + * @param twitterAuth Twitter4J Authorization object * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ @@ -320,10 +320,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Create a input stream that returns tweets received from Twitter using - * java.util.Preferences to store OAuth token. OAuth key and secret should - * be provided using system properties twitter4j.oauth.consumerKey and - * twitter4j.oauth.consumerSecret + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ @@ -347,10 +346,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Create a input stream that returns tweets received from Twitter using - * java.util.Preferences to store OAuth token. OAuth key and secret should - * be provided using system properties twitter4j.oauth.consumerKey and - * twitter4j.oauth.consumerSecret + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. * @param filters Set of filter strings to get only those tweets that match them */ def twitterStream( @@ -370,10 +368,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Create a input stream that returns tweets received from Twitter using - * java.util.Preferences to store OAuth token. OAuth key and secret should - * be provided using system properties twitter4j.oauth.consumerKey and - * twitter4j.oauth.consumerSecret + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. 
*/ def twitterStream(): JavaDStream[Status] = { ssc.twitterStream() diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index e0c654d385..ff7a58be45 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -4,21 +4,21 @@ import spark._ import spark.streaming._ import storage.StorageLevel import twitter4j._ -import twitter4j.auth.BasicAuthorization import twitter4j.auth.Authorization import java.util.prefs.Preferences +import twitter4j.conf.ConfigurationBuilder import twitter4j.conf.PropertyConfiguration import twitter4j.auth.OAuthAuthorization import twitter4j.auth.AccessToken /* A stream of Twitter statuses, potentially filtered by one or more keywords. * -* @constructor create a new Twitter stream using the supplied username and password to authenticate. +* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. * -* Includes a simple implementation of OAuth using consumer key and secret provided using system -* properties twitter4j.oauth.consumerKey and twitter4j.oauth.consumerSecret +* If no Authorization object is provided, initializes OAuth authorization using the system +* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. */ private[streaming] class TwitterInputDStream( @@ -28,28 +28,14 @@ class TwitterInputDStream( storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { - lazy val createOAuthAuthorization: Authorization = { - val userRoot = Preferences.userRoot(); - val token = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN, null)) - val tokenSecret = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, null)) - val oAuth = new OAuthAuthorization(new PropertyConfiguration(System.getProperties())) - if (token.isEmpty || tokenSecret.isEmpty) { - val requestToken = oAuth.getOAuthRequestToken() - println("Authorize application using URL: "+requestToken.getAuthorizationURL()) - println("Enter PIN: ") - val pin = Console.readLine - val accessToken = if (pin.length() > 0) oAuth.getOAuthAccessToken(requestToken, pin) else oAuth.getOAuthAccessToken() - userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN, accessToken.getToken()) - userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, accessToken.getTokenSecret()) - userRoot.flush() - } else { - oAuth.setOAuthAccessToken(new AccessToken(token.get, tokenSecret.get)); - } - oAuth + private def createOAuthAuthorization(): Authorization = { + new OAuthAuthorization(new ConfigurationBuilder().build()) } + + private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(if (twitterAuth.isEmpty) createOAuthAuthorization else twitterAuth.get, filters, storageLevel) + new TwitterReceiver(authorization, filters, storageLevel) } } diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index e5fdbe1b7a..4cf10582a9 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ 
-1267,7 +1267,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void testTwitterStream() {
     String[] filters = new String[] { "good", "bad", "ugly" };
-    JavaDStream test = ssc.twitterStream("username", "password", filters, StorageLevel.MEMORY_ONLY());
+    JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY());
   }

   @Test
-- cgit v1.2.3
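
Taken together, the Twitter patches above replace username/password authentication with an optional Twitter4J Authorization. Below is a minimal sketch (Scala) of how the reworked twitterStream API might be driven once these patches are applied; the OAuth credential values, master URL, app name and filter terms are placeholders, not anything defined in the patches.

    import spark.storage.StorageLevel
    import spark.streaming.{Seconds, StreamingContext}
    import twitter4j.auth.OAuthAuthorization
    import twitter4j.conf.ConfigurationBuilder

    object TwitterOAuthSketch {
      def main(args: Array[String]) {
        // Twitter4J's default OAuth path reads these system properties when
        // twitterStream is called with None (placeholder values shown).
        System.setProperty("twitter4j.oauth.consumerKey", "<consumerKey>")
        System.setProperty("twitter4j.oauth.consumerSecret", "<consumerSecret>")
        System.setProperty("twitter4j.oauth.accessToken", "<accessToken>")
        System.setProperty("twitter4j.oauth.accessTokenSecret", "<accessTokenSecret>")

        val ssc = new StreamingContext("local[2]", "TwitterOAuthSketch", Seconds(2),
          System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))

        // Default OAuth authorization, built from the system properties above.
        val tweets = ssc.twitterStream(None, Seq("spark"), StorageLevel.MEMORY_ONLY_SER)

        // Alternatively, pass an explicit Twitter4J Authorization object.
        val auth = new OAuthAuthorization(new ConfigurationBuilder().build())
        val moreTweets = ssc.twitterStream(Some(auth), Seq("streaming"))

        tweets.map(_.getText).print()
        ssc.start()
      }
    }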
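
Similarly, the kafkaStream overload that takes a kafkaParams map lets callers pass arbitrary Kafka consumer properties rather than only a ZooKeeper quorum and group id. A minimal sketch under the same assumptions; the quorum address, consumer group, topic name and partition count are placeholder values.

    import spark.storage.StorageLevel
    import spark.streaming.{Seconds, StreamingContext}

    object KafkaParamsSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[2]", "KafkaParamsSketch", Seconds(2),
          System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))

        // Any consumer property from http://kafka.apache.org/configuration.html
        // can be added to this map; placeholder values shown.
        val kafkaParams = Map[String, String](
          "zk.connect" -> "localhost:2181",
          "groupid" -> "spark-streaming-example",
          "zk.connectiontimeout.ms" -> "10000")

        // One consumer thread is started per partition of each topic.
        val lines = ssc.kafkaStream[String, kafka.serializer.StringDecoder](
          kafkaParams, Map("events" -> 1), StorageLevel.MEMORY_ONLY_SER_2)

        lines.count().print()
        ssc.start()
      }
    }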