From 46eecd110a4017ea0c86cbb1010d0ccd6a5eb2ef Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 31 Aug 2013 19:27:07 -0700 Subject: Initial work to rename package to org.apache.spark --- .../org/apache/spark/streaming/Checkpoint.scala | 190 +++ .../scala/org/apache/spark/streaming/DStream.scala | 702 +++++++++++ .../spark/streaming/DStreamCheckpointData.scala | 110 ++ .../org/apache/spark/streaming/DStreamGraph.scala | 167 +++ .../org/apache/spark/streaming/Duration.scala | 83 ++ .../org/apache/spark/streaming/Interval.scala | 59 + .../scala/org/apache/spark/streaming/Job.scala | 41 + .../org/apache/spark/streaming/JobManager.scala | 88 ++ .../spark/streaming/NetworkInputTracker.scala | 173 +++ .../spark/streaming/PairDStreamFunctions.scala | 534 ++++++++ .../org/apache/spark/streaming/Scheduler.scala | 131 ++ .../apache/spark/streaming/StreamingContext.scala | 563 +++++++++ .../scala/org/apache/spark/streaming/Time.scala | 72 ++ .../spark/streaming/api/java/JavaDStream.scala | 102 ++ .../spark/streaming/api/java/JavaDStreamLike.scala | 316 +++++ .../spark/streaming/api/java/JavaPairDStream.scala | 613 +++++++++ .../streaming/api/java/JavaStreamingContext.scala | 614 +++++++++ .../spark/streaming/dstream/CoGroupedDStream.scala | 57 + .../streaming/dstream/ConstantInputDStream.scala | 36 + .../spark/streaming/dstream/FileInputDStream.scala | 199 +++ .../spark/streaming/dstream/FilteredDStream.scala | 38 + .../streaming/dstream/FlatMapValuedDStream.scala | 37 + .../streaming/dstream/FlatMappedDStream.scala | 37 + .../streaming/dstream/FlumeInputDStream.scala | 154 +++ .../spark/streaming/dstream/ForEachDStream.scala | 45 + .../spark/streaming/dstream/GlommedDStream.scala | 34 + .../spark/streaming/dstream/InputDStream.scala | 70 ++ .../streaming/dstream/KafkaInputDStream.scala | 141 +++ .../streaming/dstream/MapPartitionedDStream.scala | 38 + .../spark/streaming/dstream/MapValuedDStream.scala | 38 + .../spark/streaming/dstream/MappedDStream.scala | 37 + .../streaming/dstream/NetworkInputDStream.scala | 272 ++++ .../streaming/dstream/PluggableInputDStream.scala | 30 + .../streaming/dstream/QueueInputDStream.scala | 59 + .../spark/streaming/dstream/RawInputDStream.scala | 108 ++ .../streaming/dstream/ReducedWindowedDStream.scala | 174 +++ .../spark/streaming/dstream/ShuffledDStream.scala | 44 + .../streaming/dstream/SocketInputDStream.scala | 94 ++ .../spark/streaming/dstream/StateDStream.scala | 109 ++ .../streaming/dstream/TransformedDStream.scala | 36 + .../streaming/dstream/TwitterInputDStream.scala | 99 ++ .../spark/streaming/dstream/UnionDStream.scala | 57 + .../spark/streaming/dstream/WindowedDStream.scala | 57 + .../spark/streaming/receivers/ActorReceiver.scala | 175 +++ .../spark/streaming/receivers/ZeroMQReceiver.scala | 50 + .../org/apache/spark/streaming/util/Clock.scala | 101 ++ .../spark/streaming/util/MasterFailureTest.scala | 414 +++++++ .../spark/streaming/util/RawTextHelper.scala | 115 ++ .../spark/streaming/util/RawTextSender.scala | 77 ++ .../spark/streaming/util/RecurringTimer.scala | 94 ++ .../main/scala/spark/streaming/Checkpoint.scala | 190 --- .../src/main/scala/spark/streaming/DStream.scala | 700 ----------- .../spark/streaming/DStreamCheckpointData.scala | 110 -- .../main/scala/spark/streaming/DStreamGraph.scala | 167 --- .../src/main/scala/spark/streaming/Duration.scala | 83 -- .../src/main/scala/spark/streaming/Interval.scala | 59 - streaming/src/main/scala/spark/streaming/Job.scala | 41 - .../main/scala/spark/streaming/JobManager.scala | 88 -- 
.../spark/streaming/NetworkInputTracker.scala | 173 --- .../spark/streaming/PairDStreamFunctions.scala | 534 -------- .../src/main/scala/spark/streaming/Scheduler.scala | 130 -- .../scala/spark/streaming/StreamingContext.scala | 563 --------- .../src/main/scala/spark/streaming/Time.scala | 72 -- .../spark/streaming/api/java/JavaDStream.scala | 102 -- .../spark/streaming/api/java/JavaDStreamLike.scala | 316 ----- .../spark/streaming/api/java/JavaPairDStream.scala | 613 --------- .../streaming/api/java/JavaStreamingContext.scala | 613 --------- .../spark/streaming/dstream/CoGroupedDStream.scala | 57 - .../streaming/dstream/ConstantInputDStream.scala | 36 - .../spark/streaming/dstream/FileInputDStream.scala | 199 --- .../spark/streaming/dstream/FilteredDStream.scala | 38 - .../streaming/dstream/FlatMapValuedDStream.scala | 37 - .../streaming/dstream/FlatMappedDStream.scala | 37 - .../streaming/dstream/FlumeInputDStream.scala | 154 --- .../spark/streaming/dstream/ForEachDStream.scala | 45 - .../spark/streaming/dstream/GlommedDStream.scala | 34 - .../spark/streaming/dstream/InputDStream.scala | 70 -- .../streaming/dstream/KafkaInputDStream.scala | 141 --- .../streaming/dstream/MapPartitionedDStream.scala | 38 - .../spark/streaming/dstream/MapValuedDStream.scala | 38 - .../spark/streaming/dstream/MappedDStream.scala | 37 - .../streaming/dstream/NetworkInputDStream.scala | 272 ---- .../streaming/dstream/PluggableInputDStream.scala | 30 - .../streaming/dstream/QueueInputDStream.scala | 59 - .../spark/streaming/dstream/RawInputDStream.scala | 108 -- .../streaming/dstream/ReducedWindowedDStream.scala | 174 --- .../spark/streaming/dstream/ShuffledDStream.scala | 44 - .../streaming/dstream/SocketInputDStream.scala | 94 -- .../spark/streaming/dstream/StateDStream.scala | 109 -- .../streaming/dstream/TransformedDStream.scala | 36 - .../streaming/dstream/TwitterInputDStream.scala | 99 -- .../spark/streaming/dstream/UnionDStream.scala | 57 - .../spark/streaming/dstream/WindowedDStream.scala | 57 - .../spark/streaming/receivers/ActorReceiver.scala | 175 --- .../spark/streaming/receivers/ZeroMQReceiver.scala | 50 - .../main/scala/spark/streaming/util/Clock.scala | 101 -- .../spark/streaming/util/MasterFailureTest.scala | 414 ------- .../scala/spark/streaming/util/RawTextHelper.scala | 115 -- .../scala/spark/streaming/util/RawTextSender.scala | 77 -- .../spark/streaming/util/RecurringTimer.scala | 94 -- .../org/apache/spark/streaming/JavaAPISuite.java | 1304 ++++++++++++++++++++ .../org/apache/spark/streaming/JavaTestUtils.scala | 85 ++ .../test/java/spark/streaming/JavaAPISuite.java | 1304 -------------------- .../test/java/spark/streaming/JavaTestUtils.scala | 84 -- .../spark/streaming/BasicOperationsSuite.scala | 322 +++++ .../apache/spark/streaming/CheckpointSuite.scala | 372 ++++++ .../org/apache/spark/streaming/FailureSuite.scala | 57 + .../apache/spark/streaming/InputStreamsSuite.scala | 349 ++++++ .../org/apache/spark/streaming/TestSuiteBase.scala | 314 +++++ .../spark/streaming/WindowOperationsSuite.scala | 340 +++++ .../spark/streaming/BasicOperationsSuite.scala | 322 ----- .../scala/spark/streaming/CheckpointSuite.scala | 372 ------ .../test/scala/spark/streaming/FailureSuite.scala | 57 - .../scala/spark/streaming/InputStreamsSuite.scala | 349 ------ .../test/scala/spark/streaming/TestSuiteBase.scala | 314 ----- .../spark/streaming/WindowOperationsSuite.scala | 340 ----- 116 files changed, 10827 insertions(+), 10822 deletions(-) create mode 100644 
streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Duration.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Interval.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Job.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Time.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala create mode 100644 
streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TwitterInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Duration.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Interval.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Job.scala delete mode 100644 streaming/src/main/scala/spark/streaming/JobManager.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Scheduler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/StreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Time.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala delete mode 100644 
streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/Clock.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextSender.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala delete mode 100644 streaming/src/test/java/spark/streaming/JavaAPISuite.java delete mode 100644 streaming/src/test/java/spark/streaming/JavaTestUtils.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/CheckpointSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/FailureSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/TestSuiteBase.scala delete mode 
100644 streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala (limited to 'streaming/src') diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala new file mode 100644 index 0000000000..2d8f072624 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io._ +import java.util.concurrent.Executors +import java.util.concurrent.RejectedExecutionException + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.Logging +import org.apache.spark.io.CompressionCodec + + +private[streaming] +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) + extends Logging with Serializable { + val master = ssc.sc.master + val framework = ssc.sc.appName + val sparkHome = ssc.sc.sparkHome + val jars = ssc.sc.jars + val environment = ssc.sc.environment + val graph = ssc.graph + val checkpointDir = ssc.checkpointDir + val checkpointDuration = ssc.checkpointDuration + val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() + + def validate() { + assert(master != null, "Checkpoint.master is null") + assert(framework != null, "Checkpoint.framework is null") + assert(graph != null, "Checkpoint.graph is null") + assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + logInfo("Checkpoint for time " + checkpointTime + " validated") + } +} + + +/** + * Convenience class to speed up the writing of graph checkpoint to file + */ +private[streaming] +class CheckpointWriter(checkpointDir: String) extends Logging { + val file = new Path(checkpointDir, "graph") + // The file to which we actually write - and then "move" to file. + private val writeFile = new Path(file.getParent, file.getName + ".next") + private val bakFile = new Path(file.getParent, file.getName + ".bk") + + private var stopped = false + + val conf = new Configuration() + var fs = file.getFileSystem(conf) + val maxAttempts = 3 + val executor = Executors.newFixedThreadPool(1) + + private val compressionCodec = CompressionCodec.createCodec() + + // Removed code which validates whether there is only one CheckpointWriter per path 'file' since + // I did not notice any errors - reintroduce it ? + + class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { + def run() { + var attempts = 0 + val startTime = System.currentTimeMillis() + while (attempts < maxAttempts) { + attempts += 1 + try { + logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") + // This is inherently thread unsafe .. 
so alleviating it by writing to '.new' and then doing moves : which should be pretty fast. + val fos = fs.create(writeFile) + fos.write(bytes) + fos.close() + if (fs.exists(file) && fs.rename(file, bakFile)) { + logDebug("Moved existing checkpoint file to " + bakFile) + } + // paranoia + fs.delete(file, false) + fs.rename(writeFile, file) + + val finishTime = System.currentTimeMillis(); + logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") + return + } catch { + case ioe: IOException => + logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + } + } + logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") + } + } + + def write(checkpoint: Checkpoint) { + val bos = new ByteArrayOutputStream() + val zos = compressionCodec.compressedOutputStream(bos) + val oos = new ObjectOutputStream(zos) + oos.writeObject(checkpoint) + oos.close() + bos.close() + try { + executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + } catch { + case rej: RejectedExecutionException => + logError("Could not submit checkpoint task to the thread pool executor", rej) + } + } + + def stop() { + synchronized { + if (stopped) return ; + stopped = true + } + executor.shutdown() + val startTime = System.currentTimeMillis() + val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) + val endTime = System.currentTimeMillis() + logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.") + } +} + + +private[streaming] +object CheckpointReader extends Logging { + + def read(path: String): Checkpoint = { + val fs = new Path(path).getFileSystem(new Configuration()) + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + + val compressionCodec = CompressionCodec.createCodec() + + attempts.foreach(file => { + if (fs.exists(file)) { + logInfo("Attempting to load checkpoint from file '" + file + "'") + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. 
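The safety in the write path above comes from writing to a side file and renaming, not from locking. Distilled into a standalone helper, the same publish sequence looks roughly like the sketch below; this is an illustrative reconstruction, not part of the patch, and the FileSystem handle and target Path are assumed to be supplied by the caller.

    import org.apache.hadoop.fs.{FileSystem, Path}

    // Sketch of the CheckpointWriteHandler publish sequence: write the new bytes to a
    // side file, keep the previous version as a ".bk" backup, then rename into place
    // so readers never observe a half-written checkpoint.
    def atomicPublish(fs: FileSystem, target: Path, bytes: Array[Byte]) {
      val next = new Path(target.getParent, target.getName + ".next")
      val bak = new Path(target.getParent, target.getName + ".bk")
      val out = fs.create(next)
      out.write(bytes)
      out.close()
      if (fs.exists(target)) {
        fs.rename(target, bak)   // keep the previous checkpoint as a backup
      }
      fs.delete(target, false)   // paranoia, as in the handler above
      fs.rename(next, target)    // expose the new checkpoint under its real name
    }
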
This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + val ois = new ObjectInputStreamWithLoader(zis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + logInfo("Checkpoint successfully loaded from file '" + file + "'") + logInfo("Checkpoint was generated at time " + cp.checkpointTime) + return cp + } catch { + case e: Exception => + logError("Error loading checkpoint from file '" + file + "'", e) + } + } else { + logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") + } + + }) + throw new Exception("Could not read checkpoint from path '" + path + "'") + } +} + +private[streaming] +class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) + extends ObjectInputStream(inputStream_) { + + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + try { + return loader.loadClass(desc.getName()) + } catch { + case e: Exception => + } + return super.resolveClass(desc) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala new file mode 100644 index 0000000000..362247cc38 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -0,0 +1,702 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.dstream._ +import StreamingContext._ +import org.apache.spark.util.MetadataCleaner + +//import Time._ + +import org.apache.spark.{RDD, Logging} +import org.apache.spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.conf.Configuration + +/** + * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous + * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.RDD]] + * for more details on RDDs). DStreams can either be created from live data (such as, data from + * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations + * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each + * DStream periodically generates a RDD, either from live data or by transforming the RDD generated + * by a parent DStream. + * + * This class contains the basic operations available on all DStreams, such as `map`, `filter` and + * `window`. 
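CheckpointReader and the class-loader-aware ObjectInputStreamWithLoader above are what driver recovery runs through; an application reaches them indirectly via StreamingContext. The sketch below shows the intended usage as I read the API of this era, with a placeholder checkpoint directory; it is not code from this patch, and the recovery constructor shown is an assumption about the contemporary StreamingContext API.

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // First run: enable periodic checkpointing of the DStream graph.
    val ssc = new StreamingContext("local[2]", "CheckpointExample", Seconds(1))
    ssc.checkpoint("hdfs://namenode:8020/checkpoints/example")   // placeholder directory

    // After a driver restart: rebuild the context from the saved graph, which
    // deserializes the Checkpoint through CheckpointReader.read.
    val recovered = new StreamingContext("hdfs://namenode:8020/checkpoints/example")
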
In addition, [[org.apache.spark.streaming.PairDStreamFunctions]] contains operations available + * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations + * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through + * implicit conversions when `spark.streaming.StreamingContext._` is imported. + * + * DStreams internally is characterized by a few basic properties: + * - A list of other DStreams that the DStream depends on + * - A time interval at which the DStream generates an RDD + * - A function that is used to generate an RDD after each time interval + */ + +abstract class DStream[T: ClassManifest] ( + @transient protected[streaming] var ssc: StreamingContext + ) extends Serializable with Logging { + + initLogging() + + // ======================================================================= + // Methods that should be implemented by subclasses of DStream + // ======================================================================= + + /** Time interval after which the DStream generates a RDD */ + def slideDuration: Duration + + /** List of parent DStreams on which this DStream depends on */ + def dependencies: List[DStream[_]] + + /** Method that generates a RDD for the given time */ + def compute (validTime: Time): Option[RDD[T]] + + // ======================================================================= + // Methods and fields available on all DStreams + // ======================================================================= + + // RDDs generated, marked as protected[streaming] so that testsuites can access it + @transient + protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + + // Time zero for the DStream + protected[streaming] var zeroTime: Time = null + + // Duration for which the DStream will remember each RDD created + protected[streaming] var rememberDuration: Duration = null + + // Storage level of the RDDs in the stream + protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE + + // Checkpoint details + protected[streaming] val mustCheckpoint = false + protected[streaming] var checkpointDuration: Duration = null + protected[streaming] val checkpointData = new DStreamCheckpointData(this) + + // Reference to whole DStream graph + protected[streaming] var graph: DStreamGraph = null + + protected[streaming] def isInitialized = (zeroTime != null) + + // Duration for which the DStream requires its parent DStream to remember each RDD created + protected[streaming] def parentRememberDuration = rememberDuration + + /** Return the StreamingContext associated with this DStream */ + def context = ssc + + /** Persist the RDDs of this DStream with the given storage level */ + def persist(level: StorageLevel): DStream[T] = { + if (this.isInitialized) { + throw new UnsupportedOperationException( + "Cannot change storage level of an DStream after streaming context has started") + } + this.storageLevel = level + this + } + + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) + + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def cache(): DStream[T] = persist() + + /** + * Enable periodic checkpointing of RDDs of this DStream + * @param interval Time interval after which generated RDD will be checkpointed + */ + def checkpoint(interval: Duration): DStream[T] = { + if (isInitialized) { + throw new UnsupportedOperationException( + 
"Cannot change checkpoint interval of an DStream after streaming context has started") + } + persist() + checkpointDuration = interval + this + } + + /** + * Initialize the DStream by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent DStreams. + */ + protected[streaming] def initialize(time: Time) { + if (zeroTime != null && zeroTime != time) { + throw new Exception("ZeroTime is already initialized to " + zeroTime + + ", cannot initialize it again to " + time) + } + zeroTime = time + + // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointDuration == null) { + checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt + logInfo("Checkpoint interval automatically set to " + checkpointDuration) + } + + // Set the minimum value of the rememberDuration if not already set + var minRememberDuration = slideDuration + if (checkpointDuration != null && minRememberDuration <= checkpointDuration) { + minRememberDuration = checkpointDuration * 2 // times 2 just to be sure that the latest checkpoint is not forgetten + } + if (rememberDuration == null || rememberDuration < minRememberDuration) { + rememberDuration = minRememberDuration + } + + // Initialize the dependencies + dependencies.foreach(_.initialize(zeroTime)) + } + + protected[streaming] def validate() { + assert(rememberDuration != null, "Remember duration is set to null") + + assert( + !mustCheckpoint || checkpointDuration != null, + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + + " Please use DStream.checkpoint() to set the interval." + ) + + assert( + checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, + "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + + " or SparkContext.checkpoint() to set the checkpoint directory." + ) + + assert( + checkpointDuration == null || checkpointDuration >= slideDuration, + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + + "Please set it to at least " + slideDuration + "." + ) + + assert( + checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + + "Please set it to a multiple " + slideDuration + "." + ) + + assert( + checkpointDuration == null || storageLevel != StorageLevel.NONE, + "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + "level has not been set to enable persisting. Please use DStream.persist() to set the " + + "storage level to use memory for better checkpointing performance." + ) + + assert( + checkpointDuration == null || rememberDuration > checkpointDuration, + "The remember duration for " + this.getClass.getSimpleName + " has been set to " + + rememberDuration + " which is not more than the checkpoint interval (" + + checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." 
+ ) + + val metadataCleanerDelay = MetadataCleaner.getDelaySeconds + logInfo("metadataCleanupDelay = " + metadataCleanerDelay) + assert( + metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, + "It seems you are doing some DStream window operation or setting a checkpoint interval " + + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + + "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" + + "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " + + "set the Java property 'spark.cleaner.delay' to more than " + + math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." + ) + + dependencies.foreach(_.validate()) + + logInfo("Slide time = " + slideDuration) + logInfo("Storage level = " + storageLevel) + logInfo("Checkpoint interval = " + checkpointDuration) + logInfo("Remember duration = " + rememberDuration) + logInfo("Initialized and validated " + this) + } + + protected[streaming] def setContext(s: StreamingContext) { + if (ssc != null && ssc != s) { + throw new Exception("Context is already set in " + this + ", cannot set it again") + } + ssc = s + logInfo("Set context for " + this) + dependencies.foreach(_.setContext(ssc)) + } + + protected[streaming] def setGraph(g: DStreamGraph) { + if (graph != null && graph != g) { + throw new Exception("Graph is already set in " + this + ", cannot set it again") + } + graph = g + dependencies.foreach(_.setGraph(graph)) + } + + protected[streaming] def remember(duration: Duration) { + if (duration != null && duration > rememberDuration) { + rememberDuration = duration + logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) + } + dependencies.foreach(_.remember(parentRememberDuration)) + } + + /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ + protected def isTimeValid(time: Time): Boolean = { + if (!isInitialized) { + throw new Exception (this + " has not been initialized") + } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { + logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) + false + } else { + logInfo("Time " + time + " is valid") + true + } + } + + /** + * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal + * method that should not be called directly. 
+ */ + protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + // If this DStream was not initialized (i.e., zeroTime not set), then do it + // If RDD was already generated, then retrieve it from HashMap + generatedRDDs.get(time) match { + + // If an RDD was already generated and is being reused, then + // probably all RDDs in this DStream will be reused and hence should be cached + case Some(oldRDD) => Some(oldRDD) + + // if RDD was not generated, and if the time is valid + // (based on sliding time of this DStream), then generate the RDD + case None => { + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => + if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time) + } + if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) { + newRDD.checkpoint() + logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time) + } + generatedRDDs.put(time, newRDD) + Some(newRDD) + case None => + None + } + } else { + None + } + } + } + } + + /** + * Generate a SparkStreaming job for the given time. This is an internal method that + * should not be called directly. This default implementation creates a job + * that materializes the corresponding RDD. Subclasses of DStream may override this + * to generate their own jobs. + */ + protected[streaming] def generateJob(time: Time): Option[Job] = { + getOrCompute(time) match { + case Some(rdd) => { + val jobFunc = () => { + val emptyFunc = { (iterator: Iterator[T]) => {} } + context.sparkContext.runJob(rdd, emptyFunc) + } + Some(new Job(time, jobFunc)) + } + case None => None + } + } + + /** + * Clear metadata that are older than `rememberDuration` of this DStream. + * This is an internal method that should not be called directly. This default + * implementation clears the old generated RDDs. Subclasses of DStream may override + * this to clear their own metadata along with the generated RDDs. + */ + protected[streaming] def clearOldMetadata(time: Time) { + var numForgotten = 0 + val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) + generatedRDDs --= oldRDDs.keys + logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " + + (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) + dependencies.foreach(_.clearOldMetadata(time)) + } + + /* Adds metadata to the Stream while it is running. + * This methd should be overwritten by sublcasses of InputDStream. + */ + protected[streaming] def addMetadata(metadata: Any) { + if (metadata != null) { + logInfo("Dropping Metadata: " + metadata.toString) + } + } + + /** + * Refresh the list of checkpointed RDDs that will be saved along with checkpoint of + * this stream. This is an internal method that should not be called directly. This is + * a default implementation that saves only the file names of the checkpointed RDDs to + * checkpointData. Subclasses of DStream (especially those of InputDStream) may override + * this method to save custom checkpoint data. + */ + protected[streaming] def updateCheckpointData(currentTime: Time) { + logInfo("Updating checkpoint data for time " + currentTime) + checkpointData.update() + dependencies.foreach(_.updateCheckpointData(currentTime)) + checkpointData.cleanup() + logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) + } + + /** + * Restore the RDDs in generatedRDDs from the checkpointData. 
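The default generateJob above does no real work of its own: it materializes the batch's RDD by running a job whose per-partition function discards its input, so persistence and checkpointing kick in as a side effect. In driver terms the effect is equivalent to the small sketch below, where the SparkContext and RDD are assumed stand-ins rather than names taken from the patch.

    import org.apache.spark.{RDD, SparkContext}

    // Force evaluation of every partition without collecting anything back,
    // which is all that the emptyFunc job inside generateJob does.
    def materialize[T](sc: SparkContext, rdd: RDD[T]) {
      sc.runJob(rdd, (iter: Iterator[T]) => ())
    }
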
This is an internal method + * that should not be called directly. This is a default implementation that recreates RDDs + * from the checkpoint file names stored in checkpointData. Subclasses of DStream that + * override the updateCheckpointData() method would also need to override this method. + */ + protected[streaming] def restoreCheckpointData() { + // Create RDDs from the checkpoint data + logInfo("Restoring checkpoint data") + checkpointData.restore() + dependencies.foreach(_.restoreCheckpointData()) + logInfo("Restored checkpoint data") + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + logDebug(this.getClass().getSimpleName + ".writeObject used") + if (graph != null) { + graph.synchronized { + if (graph.checkpointInProgress) { + oos.defaultWriteObject() + } else { + val msg = "Object of " + this.getClass.getName + " is being serialized " + + " possibly as a part of closure of an RDD operation. This is because " + + " the DStream object is being referred to from within the closure. " + + " Please rewrite the RDD operation inside this DStream to avoid this. " + + " This has been enforced to avoid bloating of Spark tasks " + + " with unnecessary objects." + throw new java.io.NotSerializableException(msg) + } + } + } else { + throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + logDebug(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[T]] () + } + + // ======================================================================= + // DStream operations + // ======================================================================= + + /** Return a new DStream by applying a function to all elements of this DStream. */ + def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { + new MappedDStream(this, context.sparkContext.clean(mapFunc)) + } + + /** + * Return a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { + new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) + } + + /** Return a new DStream containing only the elements that satisfy a predicate. */ + def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + + /** + * Return a new DStream in which each RDD is generated by applying glom() to each RDD of + * this DStream. Applying glom() to an RDD coalesces all elements within each partition into + * an array. + */ + def glom(): DStream[Array[T]] = new GlommedDStream(this) + + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ + def mapPartitions[U: ClassManifest]( + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean = false + ): DStream[U] = { + new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) + } + + /** + * Return a new DStream in which each RDD has a single element generated by reducing each RDD + * of this DStream. 
+ */ + def reduce(reduceFunc: (T, T) => T): DStream[T] = + this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + + /** + * Return a new DStream in which each RDD has a single element generated by counting each RDD + * of this DStream. + */ + def count(): DStream[Long] = { + this.map(_ => (null, 1L)) + .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) + .reduceByKey(_ + _) + .map(_._2) + } + + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate + * the RDDs with `numPartitions` partitions (Spark's default number of partitions if + * `numPartitions` not specified). + */ + def countByValue(numPartitions: Int = ssc.sc.defaultParallelism): DStream[(T, Long)] = + this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: RDD[T] => Unit) { + this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: (RDD[T], Time) => Unit) { + val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + newStream + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { + transform((r: RDD[T], t: Time) => transformFunc(r)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { + new TransformedDStream(this, context.sparkContext.clean(transformFunc)) + } + + /** + * Print the first ten elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ + def print() { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val first11 = rdd.take(11) + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + first11.take(10).foreach(println) + if (first11.size > 10) println("...") + println() + } + val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + } + + /** + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. The new DStream generates RDDs with + * the same interval as this DStream. + * @param windowDuration width of the window; must be a multiple of this DStream's interval. + */ + def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) + + /** + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. 
+ * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { + new WindowedDStream(this, windowDuration, slideDuration) + } + + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { + this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) + } + + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. However, the reduction is done incrementally + * using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: (T, T) => T, + invReduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { + this.map(x => (1, x)) + .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) + .map(_._2) + } + + /** + * Return a new DStream in which each RDD has a single element generated by counting the number + * of elements in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) + } + + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate + * the RDDs with `numPartitions` partitions (Spark's default number of partitions if + * `numPartitions` not specified). 
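As a driver-side illustration of how the window operations documented here compose, the sketch below counts words over a 30 second window that slides every 10 seconds. The hostname, port, durations and checkpoint directory are placeholders; reduceByKeyAndWindow with an inverse function comes from PairDStreamFunctions, which is renamed in this same commit, and becomes available through the StreamingContext._ implicits.

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._   // implicits for pair DStream operations

    val ssc = new StreamingContext("local[2]", "WindowExample", Seconds(1))
    ssc.checkpoint("/tmp/streaming-checkpoint")   // required by the inverse-reduce window path

    val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))

    // Incremental windowed counts: add counts entering the window, subtract counts leaving it.
    val counts = words.map(w => (w, 1))
      .reduceByKeyAndWindow(_ + _, _ - _, Seconds(30), Seconds(10))
    counts.print()

    ssc.start()
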
+ * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValueAndWindow( + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int = ssc.sc.defaultParallelism + ): DStream[(T, Long)] = { + + this.map(x => (x, 1L)).reduceByKeyAndWindow( + (x: Long, y: Long) => x + y, + (x: Long, y: Long) => x - y, + windowDuration, + slideDuration, + numPartitions, + (x: (T, Long)) => x._2 != 0L + ) + } + + /** + * Return a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same slideDuration as this DStream. + */ + def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) + + /** + * Return all the RDDs defined by the Interval object (both end times included) + */ + protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { + slice(interval.beginTime, interval.endTime) + } + + /** + * Return all the RDDs between 'fromTime' and 'toTime' (both included) + */ + def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { + if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) { + logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") + } + if (!(toTime - zeroTime).isMultipleOf(slideDuration)) { + logWarning("toTime (" + toTime + ") is not a multiple of slideDuration (" + slideDuration + ")") + } + val alignedToTime = toTime.floor(slideDuration) + val alignedFromTime = fromTime.floor(slideDuration) + + logInfo("Slicing from " + fromTime + " to " + toTime + + " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") + + alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { + if (time >= zeroTime) getOrCompute(time) else None + }) + } + + /** + * Save each RDD in this DStream as a Sequence file of serialized objects. + * The file name at each batch interval is generated based on `prefix` and + * `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsObjectFiles(prefix: String, suffix: String = "") { + val saveFunc = (rdd: RDD[T], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsObjectFile(file) + } + this.foreach(saveFunc) + } + + /** + * Save each RDD in this DStream as a text file, using the string representation + * of elements. The file name at each batch interval is generated based on + * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsTextFiles(prefix: String, suffix: String = "") { + val saveFunc = (rdd: RDD[T], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsTextFile(file) + } + this.foreach(saveFunc) + } + + def register() { + ssc.registerOutputStream(this) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala new file mode 100644 index 0000000000..58a0da2870 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.conf.Configuration +import collection.mutable.HashMap +import org.apache.spark.Logging + + + +private[streaming] +class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) + extends Serializable with Logging { + protected val data = new HashMap[Time, AnyRef]() + + @transient private var fileSystem : FileSystem = null + @transient private var lastCheckpointFiles: HashMap[Time, String] = null + + protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] + + /** + * Updates the checkpoint data of the DStream. This gets called every time + * the graph checkpoint is initiated. Default implementation records the + * checkpoint files to which the generate RDDs of the DStream has been saved. + */ + def update() { + + // Get the checkpointed RDDs from the generated RDDs + val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) + + // Make a copy of the existing checkpoint data (checkpointed RDDs) + lastCheckpointFiles = checkpointFiles.clone() + + // If the new checkpoint data has checkpoints then replace existing with the new one + if (newCheckpointFiles.size > 0) { + checkpointFiles.clear() + checkpointFiles ++= newCheckpointFiles + } + + // TODO: remove this, this is just for debugging + newCheckpointFiles.foreach { + case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } + } + } + + /** + * Cleanup old checkpoint data. This gets called every time the graph + * checkpoint is initiated, but after `update` is called. Default + * implementation, cleans up old checkpoint files. + */ + def cleanup() { + // If there is at least on checkpoint file in the current checkpoint files, + // then delete the old checkpoint files. + if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { + (lastCheckpointFiles -- checkpointFiles.keySet).foreach { + case (time, file) => { + try { + val path = new Path(file) + if (fileSystem == null) { + fileSystem = path.getFileSystem(new Configuration()) + } + fileSystem.delete(path, true) + logInfo("Deleted checkpoint file '" + file + "' for time " + time) + } catch { + case e: Exception => + logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) + } + } + } + } + } + + /** + * Restore the checkpoint data. This gets called once when the DStream graph + * (along with its DStreams) are being restored from a graph checkpoint file. + * Default implementation restores the RDDs from their checkpoint files. 
+ */ + def restore() { + // Create RDDs from the checkpoint data + checkpointFiles.foreach { + case(time, file) => { + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") + dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) + } + } + } + + override def toString() = { + "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala new file mode 100644 index 0000000000..b9a58fded6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import dstream.InputDStream +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import collection.mutable.ArrayBuffer +import org.apache.spark.Logging + +final private[streaming] class DStreamGraph extends Serializable with Logging { + initLogging() + + private val inputStreams = new ArrayBuffer[InputDStream[_]]() + private val outputStreams = new ArrayBuffer[DStream[_]]() + + var rememberDuration: Duration = null + var checkpointInProgress = false + + var zeroTime: Time = null + var startTime: Time = null + var batchDuration: Duration = null + + def start(time: Time) { + this.synchronized { + if (zeroTime != null) { + throw new Exception("DStream graph computation already started") + } + zeroTime = time + startTime = time + outputStreams.foreach(_.initialize(zeroTime)) + outputStreams.foreach(_.remember(rememberDuration)) + outputStreams.foreach(_.validate) + inputStreams.par.foreach(_.start()) + } + } + + def restart(time: Time) { + this.synchronized { startTime = time } + } + + def stop() { + this.synchronized { + inputStreams.par.foreach(_.stop()) + } + } + + def setContext(ssc: StreamingContext) { + this.synchronized { + outputStreams.foreach(_.setContext(ssc)) + } + } + + def setBatchDuration(duration: Duration) { + this.synchronized { + if (batchDuration != null) { + throw new Exception("Batch duration already set as " + batchDuration + + ". cannot set it again.") + } + batchDuration = duration + } + } + + def remember(duration: Duration) { + this.synchronized { + if (rememberDuration != null) { + throw new Exception("Batch duration already set as " + batchDuration + + ". 
cannot set it again.") + } + rememberDuration = duration + } + } + + def addInputStream(inputStream: InputDStream[_]) { + this.synchronized { + inputStream.setGraph(this) + inputStreams += inputStream + } + } + + def addOutputStream(outputStream: DStream[_]) { + this.synchronized { + outputStream.setGraph(this) + outputStreams += outputStream + } + } + + def getInputStreams() = this.synchronized { inputStreams.toArray } + + def getOutputStreams() = this.synchronized { outputStreams.toArray } + + def generateJobs(time: Time): Seq[Job] = { + this.synchronized { + logInfo("Generating jobs for time " + time) + val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + logInfo("Generated " + jobs.length + " jobs for time " + time) + jobs + } + } + + def clearOldMetadata(time: Time) { + this.synchronized { + logInfo("Clearing old metadata for time " + time) + outputStreams.foreach(_.clearOldMetadata(time)) + logInfo("Cleared old metadata for time " + time) + } + } + + def updateCheckpointData(time: Time) { + this.synchronized { + logInfo("Updating checkpoint data for time " + time) + outputStreams.foreach(_.updateCheckpointData(time)) + logInfo("Updated checkpoint data for time " + time) + } + } + + def restoreCheckpointData() { + this.synchronized { + logInfo("Restoring checkpoint data") + outputStreams.foreach(_.restoreCheckpointData()) + logInfo("Restored checkpoint data") + } + } + + def validate() { + this.synchronized { + assert(batchDuration != null, "Batch duration has not been set") + //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + } + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + this.synchronized { + logDebug("DStreamGraph.writeObject used") + checkpointInProgress = true + oos.defaultWriteObject() + checkpointInProgress = false + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + this.synchronized { + logDebug("DStreamGraph.readObject used") + checkpointInProgress = true + ois.defaultReadObject() + checkpointInProgress = false + } + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala new file mode 100644 index 0000000000..290ad37812 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.Utils + +case class Duration (private val millis: Long) { + + def < (that: Duration): Boolean = (this.millis < that.millis) + + def <= (that: Duration): Boolean = (this.millis <= that.millis) + + def > (that: Duration): Boolean = (this.millis > that.millis) + + def >= (that: Duration): Boolean = (this.millis >= that.millis) + + def + (that: Duration): Duration = new Duration(millis + that.millis) + + def - (that: Duration): Duration = new Duration(millis - that.millis) + + def * (times: Int): Duration = new Duration(millis * times) + + def / (that: Duration): Double = millis.toDouble / that.millis.toDouble + + def isMultipleOf(that: Duration): Boolean = + (this.millis % that.millis == 0) + + def min(that: Duration): Duration = if (this < that) this else that + + def max(that: Duration): Duration = if (this > that) this else that + + def isZero: Boolean = (this.millis == 0) + + override def toString: String = (millis.toString + " ms") + + def toFormattedString: String = millis.toString + + def milliseconds: Long = millis + + def prettyPrint = Utils.msDurationToString(millis) + +} + +/** + * Helper object that creates instance of [[org.apache.spark.streaming.Duration]] representing + * a given number of milliseconds. + */ +object Milliseconds { + def apply(milliseconds: Long) = new Duration(milliseconds) +} + +/** + * Helper object that creates instance of [[org.apache.spark.streaming.Duration]] representing + * a given number of seconds. + */ +object Seconds { + def apply(seconds: Long) = new Duration(seconds * 1000) +} + +/** + * Helper object that creates instance of [[org.apache.spark.streaming.Duration]] representing + * a given number of minutes. + */ +object Minutes { + def apply(minutes: Long) = new Duration(minutes * 60000) +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala new file mode 100644 index 0000000000..04c994c136 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +private[streaming] +class Interval(val beginTime: Time, val endTime: Time) { + def this(beginMs: Long, endMs: Long) = this(new Time(beginMs), new Time(endMs)) + + def duration(): Duration = endTime - beginTime + + def + (time: Duration): Interval = { + new Interval(beginTime + time, endTime + time) + } + + def - (time: Duration): Interval = { + new Interval(beginTime - time, endTime - time) + } + + def < (that: Interval): Boolean = { + if (this.duration != that.duration) { + throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") + } + this.endTime < that.endTime + } + + def <= (that: Interval) = (this < that || this == that) + + def > (that: Interval) = !(this <= that) + + def >= (that: Interval) = !(this < that) + + override def toString = "[" + beginTime + ", " + endTime + "]" +} + +private[streaming] +object Interval { + def currentInterval(duration: Duration): Interval = { + val time = new Time(System.currentTimeMillis) + val intervalBegin = time.floor(duration) + new Interval(intervalBegin, intervalBegin + duration) + } +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/Job.scala new file mode 100644 index 0000000000..2128b7c7a6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/Job.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.util.concurrent.atomic.AtomicLong + +private[streaming] +class Job(val time: Time, func: () => _) { + val id = Job.getNewId() + def run(): Long = { + val startTime = System.currentTimeMillis + func() + val stopTime = System.currentTimeMillis + (stopTime - startTime) + } + + override def toString = "streaming job " + id + " @ " + time +} + +private[streaming] +object Job { + val id = new AtomicLong(0) + + def getNewId() = id.getAndIncrement() +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala new file mode 100644 index 0000000000..5233129506 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import java.util.concurrent.Executors +import collection.mutable.HashMap +import collection.mutable.ArrayBuffer + + +private[streaming] +class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { + + class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { + def run() { + SparkEnv.set(ssc.env) + try { + val timeTaken = job.run() + logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format( + (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0)) + } catch { + case e: Exception => + logError("Running " + job + " failed", e) + } + clearJob(job) + } + } + + initLogging() + + val jobExecutor = Executors.newFixedThreadPool(numThreads) + val jobs = new HashMap[Time, ArrayBuffer[Job]] + + def runJob(job: Job) { + jobs.synchronized { + jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job + } + jobExecutor.execute(new JobHandler(ssc, job)) + logInfo("Added " + job + " to queue") + } + + def stop() { + jobExecutor.shutdown() + } + + private def clearJob(job: Job) { + var timeCleared = false + val time = job.time + jobs.synchronized { + val jobsOfTime = jobs.get(time) + if (jobsOfTime.isDefined) { + jobsOfTime.get -= job + if (jobsOfTime.get.isEmpty) { + jobs -= time + timeCleared = true + } + } else { + throw new Exception("Job finished for time " + job.time + + " but time does not exist in jobs") + } + } + if (timeCleared) { + ssc.scheduler.clearOldMetadata(time) + } + } + + def getPendingTimes(): Array[Time] = { + jobs.synchronized { + jobs.keySet.toArray + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala new file mode 100644 index 0000000000..aae79a4e6f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} +import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} +import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import org.apache.spark.SparkContext._ + +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +private[streaming] sealed trait NetworkInputTrackerMessage +private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage +private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage + +/** + * This class manages the execution of the receivers of NetworkInputDStreams. + */ +private[streaming] +class NetworkInputTracker( + @transient ssc: StreamingContext, + @transient networkInputStreams: Array[NetworkInputDStream[_]]) + extends Logging { + + val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) + val receiverExecutor = new ReceiverExecutor() + val receiverInfo = new HashMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds + + var currentTime: Time = null + + /** Start the actor and receiver execution thread. */ + def start() { + ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + } + + /** Stop the receiver execution thread. */ + def stop() { + // TODO: stop the actor as well + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + } + + /** Return all the blocks received from a receiver. */ + def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) + } + logInfo("Stream " + receiverId + " received " + result.size + " blocks") + result.toArray + } + + /** Actor to receive messages from the receivers. */ + private class NetworkInputTrackerActor extends Actor { + def receive = { + case RegisterReceiver(streamId, receiverActor) => { + if (!networkInputStreamMap.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) + } + receiverInfo += ((streamId, receiverActor)) + logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) + sender ! true + } + case AddBlocks(streamId, blockIds, metadata) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[String])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + networkInputStreamMap(streamId).addMetadata(metadata) + } + case DeregisterReceiver(streamId, msg) => { + receiverInfo -= streamId + logError("De-registered receiver for network stream " + streamId + + " with message " + msg) + //TODO: Do something about the corresponding NetworkInputDStream + } + } + } + + /** This thread class runs all the receivers on the cluster. 
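+   * (Editor's note: an illustrative sketch, not part of the original patch. Each network input
+   * stream created on the driver contributes one receiver that this executor ships to the
+   * cluster and runs, e.g.:
+   * {{{
+   *   val lines  = ssc.socketTextStream("localhost", 9999)   // one socket receiver; host/port made up
+   *   val events = ssc.flumeStream("localhost", 41414)       // one Flume receiver
+   * }}})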
*/
+  class ReceiverExecutor extends Thread {
+    val env = ssc.env
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        startReceivers()
+      } catch {
+        case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
+      } finally {
+        stopReceivers()
+      }
+    }
+
+    /**
+     * Get the receivers from the NetworkInputDStreams, distribute them to the
+     * worker nodes as a parallel collection, and run them.
+     */
+    def startReceivers() {
+      val receivers = networkInputStreams.map(nis => {
+        val rcvr = nis.getReceiver()
+        rcvr.setStreamId(nis.id)
+        rcvr
+      })
+
+      // Right now, we only honor preferences if all receivers have them
+      val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _)
+
+      // Create the parallel collection of receivers to distribute them on the worker nodes
+      val tempRDD =
+        if (hasLocationPreferences) {
+          val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString)))
+          ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences)
+        }
+        else {
+          ssc.sc.makeRDD(receivers, receivers.size)
+        }
+
+      // Function to start the receiver on the worker node
+      val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => {
+        if (!iterator.hasNext) {
+          throw new Exception("Could not start receiver as details not found.")
+        }
+        iterator.next().start()
+      }
+      // Run a dummy Spark job to ensure that all slaves have registered.
+      // This prevents all the receivers from being scheduled on the same node.
+      ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
+
+      // Distribute the receivers and start them
+      ssc.sparkContext.runJob(tempRDD, startReceiver)
+    }
+
+    /** Stops the receivers. */
+    def stopReceivers() {
+      // Signal the receivers to stop
+      receiverInfo.values.foreach(_ ! StopReceiver)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
new file mode 100644
index 0000000000..d8a7381e87
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.{ReducedWindowedDStream, StateDStream} +import org.apache.spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream} +import org.apache.spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream} + +import org.apache.spark.{Manifests, RDD, Partitioner, HashPartitioner} +import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.mapred.{JobConf, OutputFormat} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.conf.Configuration + +class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) +extends Serializable { + + private[streaming] def ssc = self.ssc + + private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { + new HashPartitioner(numPartitions) + } + + /** + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. + */ + def groupByKey(): DStream[(K, Seq[V])] = { + groupByKey(defaultPartitioner()) + } + + /** + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. + */ + def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { + groupByKey(defaultPartitioner(numPartitions)) + } + + /** + * Return a new DStream by applying `groupByKey` on each RDD. The supplied [[org.apache.spark.Partitioner]] + * is used to control the partitioning of each RDD. + */ + def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { + val createCombiner = (v: V) => ArrayBuffer[V](v) + val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) + val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) + combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner) + .asInstanceOf[DStream[(K, Seq[V])]] + } + + /** + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the associative reduce function. Hash partitioning is used to generate the RDDs + * with Spark's default number of partitions. + */ + def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { + reduceByKey(reduceFunc, defaultPartitioner()) + } + + /** + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs + * with `numPartitions` partitions. + */ + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { + reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) + } + + /** + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. [[org.apache.spark.Partitioner]] is used to control the + * partitioning of each RDD. + */ + def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) + } + + /** + * Combine elements of each key in DStream's RDDs using custom functions. This is similar to the + * combineByKey for RDDs. 
Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more + * information. + */ + def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner) : DStream[(K, C)] = { + new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to + * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs + * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window. Similar to + * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream; if not specified + * then Spark's default number of partitions will be used + */ + def groupByKeyAndWindow( + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int + ): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) + } + + /** + * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. + */ + def groupByKeyAndWindow( + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner + ): DStream[(K, Seq[V])] = { + self.window(windowDuration, slideDuration).groupByKey(partitioner) + } + + /** + * Return a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. 
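+   * (Editor's note: an illustrative usage sketch, not part of the original patch. Assuming a
+   * DStream[(String, Int)] named `pairs` on a context with a 1-second batch duration:
+   * {{{
+   *   import org.apache.spark.streaming.StreamingContext._   // implicit pair operations
+   *   val windowedCounts = pairs.reduceByKeyAndWindow(_ + _, Seconds(30), Seconds(10))
+   *   windowedCounts.print()
+   * }}})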
+ * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream + * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate + * the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowDuration: Duration + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) + } + + /** + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowDuration: Duration, + slideDuration: Duration + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) + } + + /** + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) + } + + /** + * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to + * `DStream.reduceByKey()`, but applies it over a sliding window. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner partitioner for controlling the partitioning of each RDD + * in the new DStream. + */ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner + ): DStream[(K, V)] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + self.reduceByKey(cleanedReduceFunc, partitioner) + .window(windowDuration, slideDuration) + .reduceByKey(cleanedReduceFunc, partitioner) + } + + /** + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. + * The reduced value of over a new window is calculated using the old window's reduced value : + * 1. 
reduce the new values that entered the window (e.g., adding new counts) + * + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * + * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param filterFunc Optional function to filter expired key-value pairs; + * only pairs that satisfy the function are retained + */ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowDuration: Duration, + slideDuration: Duration = self.slideDuration, + numPartitions: Int = ssc.sc.defaultParallelism, + filterFunc: ((K, V)) => Boolean = null + ): DStream[(K, V)] = { + + reduceByKeyAndWindow( + reduceFunc, invReduceFunc, windowDuration, + slideDuration, defaultPartitioner(numPartitions), filterFunc + ) + } + + /** + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. + * The reduced value of over a new window is calculated using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. + * @param filterFunc Optional function to filter expired key-value pairs; + * only pairs that satisfy the function are retained + */ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner, + filterFunc: ((K, V)) => Boolean + ): DStream[(K, V)] = { + + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) + val cleanedFilterFunc = if (filterFunc != null) Some(ssc.sc.clean(filterFunc)) else None + new ReducedWindowedDStream[K, V]( + self, cleanedReduceFunc, cleanedInvReduceFunc, cleanedFilterFunc, + windowDuration, slideDuration, partitioner + ) + } + + /** + * Return a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. 
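+   * (Editor's note: an illustrative sketch, not part of the original patch. A running count per
+   * key, assuming a DStream[(String, Int)] named `pairs`:
+   * {{{
+   *   import org.apache.spark.streaming.StreamingContext._   // implicit pair operations
+   *   val updateFunc = (values: Seq[Int], state: Option[Int]) => Some(values.sum + state.getOrElse(0))
+   *   val runningCounts = pairs.updateStateByKey[Int](updateFunc)
+   *   // State tracking also needs a checkpoint directory, e.g. ssc.checkpoint("/tmp/checkpoints")
+   * }}})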
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassManifest](
+      updateFunc: (Seq[V], Option[S]) => Option[S]
+    ): DStream[(K, S)] = {
+    updateStateByKey(updateFunc, defaultPartitioner())
+  }
+
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of each key.
+   * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+   * @param updateFunc State update function. If `this` function returns None, then
+   *                   corresponding state key-value pair will be eliminated.
+   * @param numPartitions Number of partitions of each RDD in the new DStream.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassManifest](
+      updateFunc: (Seq[V], Option[S]) => Option[S],
+      numPartitions: Int
+    ): DStream[(K, S)] = {
+    updateStateByKey(updateFunc, defaultPartitioner(numPartitions))
+  }
+
+  /**
+   * Create a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of the key.
+   * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If `this` function returns None, then
+   *                   corresponding state key-value pair will be eliminated.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassManifest](
+      updateFunc: (Seq[V], Option[S]) => Option[S],
+      partitioner: Partitioner
+    ): DStream[(K, S)] = {
+    val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
+      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+    }
+    updateStateByKey(newUpdateFunc, partitioner, true)
+  }
+
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of each key.
+   * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If `this` function returns None, then
+   *                   corresponding state key-value pair will be eliminated. Note that
+   *                   this function may generate a tuple with a different key
+   *                   than the input key. It is up to the developer to decide whether to
+   *                   remember the partitioner despite the key being changed.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+   * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassManifest](
+      updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+      partitioner: Partitioner,
+      rememberPartitioner: Boolean
+    ): DStream[(K, S)] = {
+     new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+  }
+
+
+  def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = {
+    new MapValuedDStream[K, V, U](self, mapValuesFunc)
+  }
+
+  def flatMapValues[U: ClassManifest](
+      flatMapValuesFunc: V => TraversableOnce[U]
+    ): DStream[(K, U)] = {
+    new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc)
+  }
+
+  /**
+   * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
+   * or `other` DStreams, the generated RDD will contain a tuple with the list of values for that
+   * key in both RDDs.
HashPartitioner is used to partition each generated RDD into default number + * of partitions. + */ + def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner()) + } + + /** + * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. Partitioner is used to partition each generated RDD. + */ + def cogroup[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (Seq[V], Seq[W]))] = { + + val cgd = new CoGroupedDStream[K]( + Seq(self.asInstanceOf[DStream[(K, _)]], other.asInstanceOf[DStream[(K, _)]]), + partitioner + ) + val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)( + classManifest[K], + Manifests.seqSeqManifest + ) + pdfs.mapValues { + case Seq(vs, ws) => + (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + } + } + + /** + * Join `this` DStream with `other` DStream. HashPartitioner is used + * to partition each generated RDD into default number of partitions. + */ + def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { + join[W](other, defaultPartitioner()) + } + + /** + * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * be generated by joining RDDs from `this` and other DStream. Uses the given + * Partitioner to partition each generated RDD. + */ + def join[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (V, W))] = { + this.cogroup(other, partitioner) + .flatMapValues{ + case (vs, ws) => + for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ + def saveAsHadoopFiles[F <: OutputFormat[K, V]]( + prefix: String, + suffix: String + )(implicit fm: ClassManifest[F]) { + saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ + def saveAsHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf = new JobConf + ) { + val saveFunc = (rdd: RDD[(K, V)], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) + } + self.foreach(saveFunc) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( + prefix: String, + suffix: String + )(implicit fm: ClassManifest[F]) { + saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". 
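+   * (Editor's note: an illustrative sketch, not part of the original patch. Assuming a
+   * DStream[(Text, IntWritable)] named `writableCounts`, the files could be written as:
+   * {{{
+   *   import org.apache.hadoop.io.{IntWritable, Text}
+   *   import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
+   *
+   *   writableCounts.saveAsNewAPIHadoopFiles(
+   *     "hdfs://namenode:8020/streaming/counts", "seq",   // prefix and suffix are made up
+   *     classOf[Text], classOf[IntWritable],
+   *     classOf[SequenceFileOutputFormat[Text, IntWritable]])
+   * }}})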
+ */ + def saveAsNewAPIHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = new Configuration + ) { + val saveFunc = (rdd: RDD[(K, V)], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) + } + self.foreach(saveFunc) + } + + private def getKeyClass() = implicitly[ClassManifest[K]].erasure + + private def getValueClass() = implicitly[ClassManifest[V]].erasure +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala new file mode 100644 index 0000000000..ed892e33e6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import util.{ManualClock, RecurringTimer, Clock} +import org.apache.spark.SparkEnv +import org.apache.spark.Logging + +private[streaming] +class Scheduler(ssc: StreamingContext) extends Logging { + + initLogging() + + val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt + val jobManager = new JobManager(ssc, concurrentJobs) + val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { + new CheckpointWriter(ssc.checkpointDir) + } else { + null + } + + val clockClass = System.getProperty( + "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] + val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, + longTime => generateJobs(new Time(longTime))) + val graph = ssc.graph + var latestTime: Time = null + + def start() = synchronized { + if (ssc.isCheckpointPresent) { + restart() + } else { + startFirstTime() + } + logInfo("Scheduler started") + } + + def stop() = synchronized { + timer.stop() + jobManager.stop() + if (checkpointWriter != null) checkpointWriter.stop() + ssc.graph.stop() + logInfo("Scheduler stopped") + } + + private def startFirstTime() { + val startTime = new Time(timer.getStartTime()) + graph.start(startTime - graph.batchDuration) + timer.start(startTime.milliseconds) + logInfo("Scheduler's timer started at " + startTime) + } + + private def restart() { + + // If manual clock is being used for testing, then + // either set the manual clock to the last checkpointed time, + // or if the property is defined set it to that time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds + val jumpTime = 
System.getProperty("spark.streaming.manualClock.jump", "0").toLong + clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) + } + + val batchDuration = ssc.graph.batchDuration + + // Batches when the master was down, that is, + // between the checkpoint and current restart time + val checkpointTime = ssc.initialCheckpoint.checkpointTime + val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) + val downTimes = checkpointTime.until(restartTime, batchDuration) + logInfo("Batches during down time: " + downTimes.mkString(", ")) + + // Batches that were unprocessed before failure + val pendingTimes = ssc.initialCheckpoint.pendingTimes + logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + // Reschedule jobs for these times + val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + timesToReschedule.foreach(time => + graph.generateJobs(time).foreach(jobManager.runJob) + ) + + // Restart the timer + timer.start(restartTime.milliseconds) + logInfo("Scheduler's timer restarted at " + restartTime) + } + + /** Generate jobs and perform checkpoint for the given `time`. */ + def generateJobs(time: Time) { + SparkEnv.set(ssc.env) + logInfo("\n-----------------------------------------------------\n") + graph.generateJobs(time).foreach(jobManager.runJob) + latestTime = time + doCheckpoint(time) + } + + /** + * Clear old metadata assuming jobs of `time` have finished processing. + * And also perform checkpoint. + */ + def clearOldMetadata(time: Time) { + ssc.graph.clearOldMetadata(time) + doCheckpoint(time) + } + + /** Perform checkpoint for the give `time`. */ + def doCheckpoint(time: Time) = synchronized { + if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + logInfo("Checkpointing graph for time " + time) + ssc.graph.updateCheckpointData(time) + checkpointWriter.write(new Checkpoint(ssc, time)) + } + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala new file mode 100644 index 0000000000..3852ac2dab --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -0,0 +1,563 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import akka.actor.Props +import akka.actor.SupervisorStrategy +import akka.zeromq.Subscribe + +import org.apache.spark.streaming.dstream._ + +import org.apache.spark._ +import org.apache.spark.streaming.receivers.ActorReceiver +import org.apache.spark.streaming.receivers.ReceiverSupervisorStrategy +import org.apache.spark.streaming.receivers.ZeroMQReceiver +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.MetadataCleaner +import org.apache.spark.streaming.receivers.ActorReceiver + +import scala.collection.mutable.Queue +import scala.collection.Map + +import java.io.InputStream +import java.util.concurrent.atomic.AtomicInteger +import java.util.UUID + +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path +import twitter4j.Status +import twitter4j.auth.Authorization + + +/** + * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic + * information (such as, cluster URL and job name) to internally create a SparkContext, it provides + * methods used to create DStream from various input sources. + */ +class StreamingContext private ( + sc_ : SparkContext, + cp_ : Checkpoint, + batchDur_ : Duration + ) extends Logging { + + /** + * Create a StreamingContext using an existing SparkContext. + * @param sparkContext Existing SparkContext + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(sparkContext: SparkContext, batchDuration: Duration) = { + this(sparkContext, null, batchDuration) + } + + /** + * Create a StreamingContext by providing the details necessary for creating a new SparkContext. + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your job, to display on the cluster web UI + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this( + master: String, + appName: String, + batchDuration: Duration, + sparkHome: String = null, + jars: Seq[String] = Nil, + environment: Map[String, String] = Map()) = { + this(StreamingContext.createNewSparkContext(master, appName, sparkHome, jars, environment), + null, batchDuration) + } + + + /** + * Re-create a StreamingContext from a checkpoint file. + * @param path Path either to the directory that was specified as the checkpoint directory, or + * to the checkpoint file 'graph' or 'graph.bk'. 
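+   * (Editor's note: an illustrative sketch, not part of the original patch. On driver restart the
+   * context could be rebuilt from a previously written checkpoint; the path is made up:
+   * {{{
+   *   val ssc = new StreamingContext("hdfs://namenode:8020/streaming/checkpoints")
+   *   ssc.start()
+   * }}})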
+ */ + def this(path: String) = this(null, CheckpointReader.read(path), null) + + initLogging() + + if (sc_ == null && cp_ == null) { + throw new Exception("Spark Streaming cannot be initialized with " + + "both SparkContext and checkpoint as null") + } + + if (MetadataCleaner.getDelaySeconds < 0) { + throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " + + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") + } + + protected[streaming] val isCheckpointPresent = (cp_ != null) + + protected[streaming] val sc: SparkContext = { + if (isCheckpointPresent) { + new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars, cp_.environment) + } else { + sc_ + } + } + + protected[streaming] val env = SparkEnv.get + + protected[streaming] val graph: DStreamGraph = { + if (isCheckpointPresent) { + cp_.graph.setContext(this) + cp_.graph.restoreCheckpointData() + cp_.graph + } else { + assert(batchDur_ != null, "Batch duration for streaming context cannot be null") + val newGraph = new DStreamGraph() + newGraph.setBatchDuration(batchDur_) + newGraph + } + } + + protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) + protected[streaming] var networkInputTracker: NetworkInputTracker = null + + protected[streaming] var checkpointDir: String = { + if (isCheckpointPresent) { + sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true) + cp_.checkpointDir + } else { + null + } + } + + protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null + protected[streaming] var receiverJobThread: Thread = null + protected[streaming] var scheduler: Scheduler = null + + /** + * Return the associated Spark context + */ + def sparkContext = sc + + /** + * Set each DStreams in this context to remember RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * collection. This method allows the developer to specify how to long to remember the RDDs ( + * if the developer wishes to query old data outside the DStream computation). + * @param duration Minimum duration that each DStream should remember its RDDs + */ + def remember(duration: Duration) { + graph.remember(duration) + } + + /** + * Set the context to periodically checkpoint the DStream operations for master + * fault-tolerance. The graph will be checkpointed every batch interval. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + */ + def checkpoint(directory: String) { + if (directory != null) { + sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) + checkpointDir = directory + } else { + checkpointDir = null + } + } + + protected[streaming] def initialCheckpoint: Checkpoint = { + if (isCheckpointPresent) cp_ else null + } + + protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + + /** + * Create an input stream with any arbitrary user implemented network receiver. 
+ * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html + * @param receiver Custom implementation of NetworkReceiver + */ + def networkStream[T: ClassManifest]( + receiver: NetworkReceiver[T]): DStream[T] = { + val inputStream = new PluggableInputDStream[T](this, + receiver) + graph.addInputStream(inputStream) + inputStream + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html + * @param props Props object defining creation of the actor + * @param name Name of the actor + * @param storageLevel RDD storage level. Defaults to memory-only. + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T: ClassManifest]( + props: Props, + name: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy + ): DStream[T] = { + networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic + * and each frame has sequence of byte thus it needs the converter + * (which might be deserializer of bytes) to translate from sequence + * of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + * @param storageLevel RDD storage level. Defaults to memory-only. + */ + def zeroMQStream[T: ClassManifest]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T], + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy + ): DStream[T] = { + actorStream(Props(new ZeroMQReceiver(publisherUrl,subscribe,bytesToObjects)), + "ZeroMQReceiver", storageLevel, supervisorStrategy) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param storageLevel Storage level to use for storing the received objects + * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + */ + def kafkaStream( + zkQuorum: String, + groupId: String, + topics: Map[String, Int], + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 + ): DStream[String] = { + val kafkaParams = Map[String, String]( + "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000") + kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. 
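+   * (Editor's note: an illustrative sketch, not part of the original patch. The simpler overload
+   * above could be used as follows; the quorum, group and topic names are made up:
+   * {{{
+   *   val kafkaLines: DStream[String] =
+   *     ssc.kafkaStream("zk1:2181,zk2:2181", "log-consumers", Map("pageviews" -> 2))
+   * }}})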
+   * @param storageLevel Storage level to use for storing the received objects
+   */
+  def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
+      kafkaParams: Map[String, String],
+      topics: Map[String, Int],
+      storageLevel: StorageLevel
+    ): DStream[T] = {
+    val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
+    registerInputStream(inputStream)
+    inputStream
+  }
+
+  /**
+   * Create an input stream from a TCP source hostname:port. Data is received using
+   * a TCP socket and the received bytes are interpreted as UTF8-encoded, `\n`-delimited
+   * lines.
+   * @param hostname Hostname to connect to for receiving data
+   * @param port Port to connect to for receiving data
+   * @param storageLevel Storage level to use for storing the received objects
+   *                     (default: StorageLevel.MEMORY_AND_DISK_SER_2)
+   */
+  def socketTextStream(
+      hostname: String,
+      port: Int,
+      storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+    ): DStream[String] = {
+    socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
+  }
+
+  /**
+   * Create an input stream from a TCP source hostname:port. Data is received using
+   * a TCP socket and the received bytes are interpreted as objects using the given
+   * converter.
+   * @param hostname Hostname to connect to for receiving data
+   * @param port Port to connect to for receiving data
+   * @param converter Function to convert the byte stream to objects
+   * @param storageLevel Storage level to use for storing the received objects
+   * @tparam T Type of the objects received (after converting bytes to objects)
+   */
+  def socketStream[T: ClassManifest](
+      hostname: String,
+      port: Int,
+      converter: (InputStream) => Iterator[T],
+      storageLevel: StorageLevel
+    ): DStream[T] = {
+    val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel)
+    registerInputStream(inputStream)
+    inputStream
+  }
+
+  /**
+   * Create an input stream from a Flume source.
+   * @param hostname Hostname of the slave machine to which the flume data will be sent
+   * @param port Port of the slave machine to which the flume data will be sent
+   * @param storageLevel Storage level to use for storing the received objects
+   */
+  def flumeStream (
+      hostname: String,
+      port: Int,
+      storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+    ): DStream[SparkFlumeEvent] = {
+    val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel)
+    registerInputStream(inputStream)
+    inputStream
+  }
+
+  /**
+   * Create an input stream from a network source hostname:port, where data is received
+   * as serialized blocks (serialized using Spark's serializer) that can be directly
+   * pushed into the block manager without deserializing them. This is the most efficient
+   * way to receive data.
+   * @param hostname Hostname to connect to for receiving data
+   * @param port Port to connect to for receiving data
+   * @param storageLevel Storage level to use for storing the received objects
+   * @tparam T Type of the objects in the received blocks
+   */
+  def rawSocketStream[T: ClassManifest](
+      hostname: String,
+      port: Int,
+      storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+    ): DStream[T] = {
+    val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel)
+    registerInputStream(inputStream)
+    inputStream
+  }
+
+  /**
+   * Create an input stream that monitors a Hadoop-compatible filesystem
+   * for new files and reads them using the given key-value types and input format.
+   * File names starting with .
are ignored. + * @param directory HDFS directory to monitor for new file + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ] (directory: String): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, directory) + registerInputStream(inputStream) + inputStream + } + + /** + * Create a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) + registerInputStream(inputStream) + inputStream + } + + /** + * Create a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as text files (using key as LongWritable, value + * as Text and input format as TextInputFormat). File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + */ + def textFileStream(directory: String): DStream[String] = { + fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + } + + /** + * Create a input stream that returns tweets received from Twitter. + * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth + * authorization; this uses the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret. + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + twitterAuth: Option[Authorization] = None, + filters: Seq[String] = Nil, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[Status] = { + val inputStream = new TwitterInputDStream(this, twitterAuth, filters, storageLevel) + registerInputStream(inputStream) + inputStream + } + + /** + * Create an input stream from a queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @tparam T Type of objects in the RDD + */ + def queueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean = true + ): DStream[T] = { + queueStream(queue, oneAtATime, sc.makeRDD(Seq[T](), 1)) + } + + /** + * Create an input stream from a queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. 
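A small queueStream sketch, useful for local testing; it assumes an existing `ssc: StreamingContext` and uses `ssc.sparkContext` (defined above) to build test RDDs:

import scala.collection.mutable.Queue
import org.apache.spark.RDD

val rddQueue = new Queue[RDD[Int]]()
val queuedInts = ssc.queueStream(rddQueue)            // one queued RDD per batch by default
queuedInts.map(_ * 2).print()
rddQueue += ssc.sparkContext.makeRDD(1 to 100, 2)     // push a test RDD into the queue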
+ * Set as null if no RDD should be returned when empty + * @tparam T Type of objects in the RDD + */ + def queueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ): DStream[T] = { + val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) + registerInputStream(inputStream) + inputStream + } + + /** + * Create a unified DStream from multiple DStreams of the same type and same interval + */ + def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = { + new UnionDStream[T](streams.toArray) + } + + /** + * Register an input stream that will be started (InputDStream.start() called) to get the + * input data. + */ + def registerInputStream(inputStream: InputDStream[_]) { + graph.addInputStream(inputStream) + } + + /** + * Register an output stream that will be computed every interval + */ + def registerOutputStream(outputStream: DStream[_]) { + graph.addOutputStream(outputStream) + } + + protected def validate() { + assert(graph != null, "Graph is null") + graph.validate() + + assert( + checkpointDir == null || checkpointDuration != null, + "Checkpoint directory has been set, but the graph checkpointing interval has " + + "not been set. Please use StreamingContext.checkpoint() to set the interval." + ) + } + + /** + * Start the execution of the streams. + */ + def start() { + if (checkpointDir != null && checkpointDuration == null && graph != null) { + checkpointDuration = graph.batchDuration + } + + validate() + + val networkInputStreams = graph.getInputStreams().filter(s => s match { + case n: NetworkInputDStream[_] => true + case _ => false + }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray + + if (networkInputStreams.length > 0) { + // Start the network input tracker (must start before receivers) + networkInputTracker = new NetworkInputTracker(this, networkInputStreams) + networkInputTracker.start() + } + + Thread.sleep(1000) + + // Start the scheduler + scheduler = new Scheduler(this) + scheduler.start() + } + + /** + * Stop the execution of the streams. + */ + def stop() { + try { + if (scheduler != null) scheduler.stop() + if (networkInputTracker != null) networkInputTracker.stop() + if (receiverJobThread != null) receiverJobThread.interrupt() + sc.stop() + logInfo("StreamingContext stopped successfully") + } catch { + case e: Exception => logWarning("Error while stopping", e) + } + } +} + + +object StreamingContext { + + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { + new PairDStreamFunctions[K, V](stream) + } + + protected[streaming] def createNewSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Seq[String], + environment: Map[String, String]): SparkContext = { + // Set the default cleaner delay to an hour if not already set. + // This should be sufficient for even 1 second interval. + if (MetadataCleaner.getDelaySeconds < 0) { + MetadataCleaner.setDelaySeconds(3600) + } + new SparkContext(master, appName, sparkHome, jars, environment) + } + + protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + if (prefix == null) { + time.milliseconds.toString + } else if (suffix == null || suffix.length ==0) { + prefix + "-" + time.milliseconds + } else { + prefix + "-" + time.milliseconds + "." 
+ suffix + } + } + + protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = { + new Path(sscCheckpointDir, UUID.randomUUID.toString).toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala new file mode 100644 index 0000000000..2678334f53 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +/** + * This is a simple class that represents an absolute instant of time. + * Internally, it represents time as the difference, measured in milliseconds, between the current + * time and midnight, January 1, 1970 UTC. This is the same format as what is returned by + * System.currentTimeMillis. + */ +case class Time(private val millis: Long) { + + def milliseconds: Long = millis + + def < (that: Time): Boolean = (this.millis < that.millis) + + def <= (that: Time): Boolean = (this.millis <= that.millis) + + def > (that: Time): Boolean = (this.millis > that.millis) + + def >= (that: Time): Boolean = (this.millis >= that.millis) + + def + (that: Duration): Time = new Time(millis + that.milliseconds) + + def - (that: Time): Duration = new Duration(millis - that.millis) + + def - (that: Duration): Time = new Time(millis - that.milliseconds) + + def floor(that: Duration): Time = { + val t = that.milliseconds + val m = math.floor(this.millis / t).toLong + new Time(m * t) + } + + def isMultipleOf(that: Duration): Boolean = + (this.millis % that.milliseconds == 0) + + def min(that: Time): Time = if (this < that) this else that + + def max(that: Time): Time = if (this > that) this else that + + def until(that: Time, interval: Duration): Seq[Time] = { + (this.milliseconds) until (that.milliseconds) by (interval.milliseconds) map (new Time(_)) + } + + def to(that: Time, interval: Duration): Seq[Time] = { + (this.milliseconds) to (that.milliseconds) by (interval.milliseconds) map (new Time(_)) + } + + + override def toString: String = (millis.toString + " ms") + +} + +object Time { + val ordering = Ordering.by((time: Time) => time.millis) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala new file mode 100644 index 0000000000..f8c8d8ece1 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.java
+
+import org.apache.spark.streaming.{Duration, Time, DStream}
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.RDD
+
+/**
+ * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.RDD]]
+ * for more details on RDDs). DStreams can either be created from live data (such as data from
+ * HDFS, Kafka or Flume) or be generated by transforming existing DStreams using operations
+ * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
+ * DStream periodically generates an RDD, either from live data or by transforming the RDD generated
+ * by a parent DStream.
+ *
+ * This class contains the basic operations available on all DStreams, such as `map`, `filter` and
+ * `window`. In addition, [[org.apache.spark.streaming.api.java.JavaPairDStream]] contains operations available
+ * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`.
+ *
+ * Each DStream is internally characterized by a few basic properties:
+ *  - A list of other DStreams that the DStream depends on
+ *  - A time interval at which the DStream generates an RDD
+ *  - A function that is used to generate an RDD after each time interval
+ */
+class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T])
+    extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] {
+
+  override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
+
+  /** Return a new DStream containing only the elements that satisfy a predicate. */
+  def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] =
+    dstream.filter((x => f(x).booleanValue()))
+
+  /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+  def cache(): JavaDStream[T] = dstream.cache()
+
+  /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+  def persist(): JavaDStream[T] = dstream.persist()
+
+  /** Persist the RDDs of this DStream with the given storage level */
+  def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel)
+
+  /** Generate an RDD for the given time */
+  def compute(validTime: Time): JavaRDD[T] = {
+    dstream.compute(validTime) match {
+      case Some(rdd) => new JavaRDD(rdd)
+      case None => null
+    }
+  }
+
+  /**
+   * Return a new DStream in which each RDD contains all the elements seen in a
+   * sliding window of time over this DStream. The new DStream generates RDDs with
+   * the same interval as this DStream.
+   * @param windowDuration width of the window; must be a multiple of this DStream's interval.
+ */ + def window(windowDuration: Duration): JavaDStream[T] = + dstream.window(windowDuration) + + /** + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] = + dstream.window(windowDuration, slideDuration) + + /** + * Return a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. + */ + def union(that: JavaDStream[T]): JavaDStream[T] = + dstream.union(that.dstream) +} + +object JavaDStream { + implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] = + new JavaDStream[T](dstream) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala new file mode 100644 index 0000000000..2e6fe9a9c4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import java.util.{List => JList} +import java.lang.{Long => JLong} + +import scala.collection.JavaConversions._ + +import org.apache.spark.streaming._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD} +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} +import java.util +import org.apache.spark.RDD +import JavaDStream._ + +trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] + extends Serializable { + implicit val classManifest: ClassManifest[T] + + def dstream: DStream[T] + + def wrapRDD(in: RDD[T]): R + + implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { + in.map(new JLong(_)) + } + + /** + * Print the first ten elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ + def print() = dstream.print() + + /** + * Return a new DStream in which each RDD has a single element generated by counting each RDD + * of this DStream. + */ + def count(): JavaDStream[JLong] = dstream.count() + + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. 
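An illustrative windowing sketch on the underlying Scala DStream that this wrapper delegates to; it assumes an existing `ssc`, a socket source on a placeholder host/port, and the `Seconds` helper from Duration.scala:

val lines = ssc.socketTextStream("localhost", 9999)      // placeholder host and port
val windowed = lines.window(Seconds(30), Seconds(10))    // 30s window, sliding every 10s
windowed.count().print()                                 // element count per window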
Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + */ + def countByValue(): JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByValue()) + } + + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) + } + + + /** + * Return a new DStream in which each RDD has a single element generated by counting the number + * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the + * window() operation. This is equivalent to window(windowDuration, slideDuration).count() + */ + def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = { + dstream.countByWindow(windowDuration, slideDuration) + } + + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) + : JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong( + dstream.countByValueAndWindow(windowDuration, slideDuration)) + } + + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) + : JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong( + dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) + } + + /** + * Return a new DStream in which each RDD is generated by applying glom() to each RDD of + * this DStream. Applying glom() to an RDD coalesces all elements within each partition into + * an array. + */ + def glom(): JavaDStream[JList[T]] = + new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + + /** Return the StreamingContext associated with this DStream */ + def context(): StreamingContext = dstream.context() + + /** Return a new DStream by applying a function to all elements of this DStream. */ + def map[R](f: JFunction[T, R]): JavaDStream[R] = { + new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) + } + + /** Return a new DStream by applying a function to all elements of this DStream. 
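For illustration, the Scala counterparts of these counting operations, assuming an existing `lines: DStream[String]`:

val words = lines.flatMap(_.split(" "))
words.countByValue().print()                                   // (word, count) per batch
words.countByValueAndWindow(Seconds(60), Seconds(10)).print()  // counts over the last minute, every 10s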
*/ + def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] + new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) + } + + /** + * Return a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ + def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + new JavaDStream(dstream.flatMap(fn)(f.elementType()))(f.elementType()) + } + + /** + * Return a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ + def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] + new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType()) + } + + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ + def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) + } + + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ + def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) + : JavaPairDStream[K2, V2] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType()) + } + + /** + * Return a new DStream in which each RDD has a single element generated by reducing each RDD + * of this DStream. + */ + def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) + + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + + + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. However, the reduction is done incrementally + * using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". 
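A sketch of the incremental reduceByWindow described here, on the Scala DStream this wrapper delegates to; `nums: DStream[Long]` is an assumed input:

// Subtraction is the inverse of addition, so old values can leave the window cheaply.
val windowedSums = nums.reduceByWindow(
  (a: Long, b: Long) => a + b,     // reduce new values entering the window
  (a: Long, b: Long) => a - b,     // "inverse reduce" values leaving the window
  Seconds(60), Seconds(10))
windowedSums.print()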
+ * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + invReduceFunc: JFunction2[T, T, T], + windowDuration: Duration, + slideDuration: Duration + ): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) + } + + /** + * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) + */ + def slice(fromTime: Time, toTime: Time): JList[R] = { + new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: JFunction[R, Void]) { + dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: JFunction2[R, Time, Void]) { + dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + def scalaTransform (in: RDD[T]): RDD[U] = + transformFunc.call(wrapRDD(in)).rdd + dstream.transform(scalaTransform(_)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + def scalaTransform (in: RDD[T], time: Time): RDD[U] = + transformFunc.call(wrapRDD(in), time).rdd + dstream.transform(scalaTransform(_, _)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]): + JavaPairDStream[K2, V2] = { + implicit val cmk: ClassManifest[K2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]] + implicit val cmv: ClassManifest[V2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]] + def scalaTransform (in: RDD[T]): RDD[(K2, V2)] = + transformFunc.call(wrapRDD(in)).rdd + dstream.transform(scalaTransform(_)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. 
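An illustrative transform sketch on the underlying Scala DStream: each batch RDD is joined against a small static RDD. `events: DStream[(String, Int)]` and the lookup data are assumptions; the pair-RDD implicits come from `org.apache.spark.SparkContext._`:

import org.apache.spark.SparkContext._                     // pair-RDD functions such as join
val ranks = ssc.sparkContext.makeRDD(Seq(("a", 1.0), ("b", 2.0)))
val joined = events.transform(rdd => rdd.join(ranks))      // DStream[(String, (Int, Double))]
joined.print()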
+ */ + def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]): + JavaPairDStream[K2, V2] = { + implicit val cmk: ClassManifest[K2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]] + implicit val cmv: ClassManifest[V2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]] + def scalaTransform (in: RDD[T], time: Time): RDD[(K2, V2)] = + transformFunc.call(wrapRDD(in), time).rdd + dstream.transform(scalaTransform(_, _)) + } + + /** + * Enable periodic checkpointing of RDDs of this DStream + * @param interval Time interval after which generated RDD will be checkpointed + */ + def checkpoint(interval: Duration) = { + dstream.checkpoint(interval) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala new file mode 100644 index 0000000000..c203dccd17 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -0,0 +1,613 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import java.util.{List => JList} +import java.lang.{Long => JLong} + +import scala.collection.JavaConversions._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.{RDD, Partitioner} +import org.apache.hadoop.mapred.{JobConf, OutputFormat} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.api.java.{JavaUtils, JavaRDD, JavaPairRDD} +import org.apache.spark.storage.StorageLevel +import com.google.common.base.Optional +import org.apache.spark.RDD + +class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( + implicit val kManifiest: ClassManifest[K], + implicit val vManifest: ClassManifest[V]) + extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { + + override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) + + // ======================================================================= + // Methods common to all DStream's + // ======================================================================= + + /** Return a new DStream containing only the elements that satisfy a predicate. 
*/ + def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = + dstream.filter((x => f(x).booleanValue())) + + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def cache(): JavaPairDStream[K, V] = dstream.cache() + + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def persist(): JavaPairDStream[K, V] = dstream.persist() + + /** Persist the RDDs of this DStream with the given storage level */ + def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) + + /** Method that generates a RDD for the given Duration */ + def compute(validTime: Time): JavaPairRDD[K, V] = { + dstream.compute(validTime) match { + case Some(rdd) => new JavaPairRDD(rdd) + case None => null + } + } + + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * The new DStream generates RDDs with the same interval as this DStream. + * @param windowDuration width of the window; must be a multiple of this DStream's interval. + * @return + */ + def window(windowDuration: Duration): JavaPairDStream[K, V] = + dstream.window(windowDuration) + + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * @param windowDuration duration (i.e., width) of the window; + * must be a multiple of this DStream's interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's interval + */ + def window(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = + dstream.window(windowDuration, slideDuration) + + /** + * Return a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. + */ + def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = + dstream.union(that.dstream) + + // ======================================================================= + // Methods only for PairDStream's + // ======================================================================= + + /** + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. + */ + def groupByKey(): JavaPairDStream[K, JList[V]] = + dstream.groupByKey().mapValues(seqAsJavaList _) + + /** + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. + */ + def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = + dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) + + /** + * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. [[org.apache.spark.Partitioner]] + * is used to control the partitioning of each RDD. + */ + def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = + dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) + + /** + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the associative reduce function. Hash partitioning is used to generate the RDDs + * with Spark's default number of partitions. 
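The classic word-count sketch on the Scala pair-DStream functions that back this wrapper; `lines: DStream[String]` is assumed, and the implicit conversion comes from `StreamingContext.toPairDStreamFunctions` (shown earlier in this patch):

import org.apache.spark.streaming.StreamingContext._   // toPairDStreamFunctions implicit
val wordCounts = lines.flatMap(_.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
wordCounts.print()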
+ */ + def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = + dstream.reduceByKey(func) + + /** + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs + * with `numPartitions` partitions. + */ + def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = + dstream.reduceByKey(func, numPartitions) + + /** + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. [[org.apache.spark.Partitioner]] is used to control the + * partitioning of each RDD. + */ + def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { + dstream.reduceByKey(func, partitioner) + } + + /** + * Combine elements of each key in DStream's RDDs using custom function. This is similar to the + * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more + * information. + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner + ): JavaPairDStream[K, C] = { + implicit val cm: ClassManifest[C] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] + dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to + * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs + * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration).mapValues(seqAsJavaList _) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window. Similar to + * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) + : JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(seqAsJavaList _) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. 
+ */ + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) + :JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) + .mapValues(seqAsJavaList _) + } + + /** + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ + def groupByKeyAndWindow( + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner + ):JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) + .mapValues(seqAsJavaList _) + } + + /** + * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream + * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate + * the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration) + :JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) + } + + /** + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration + ):JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) + } + + /** + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. 
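An illustrative windowed word count using the simple (non-incremental) overload, assuming `pairs: DStream[(String, Int)]` built as in the previous sketch (with the same implicit in scope):

val windowedCounts = pairs.reduceByKeyAndWindow((a: Int, b: Int) => a + b, Seconds(30), Seconds(10))
windowedCounts.print()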
+   */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      numPartitions: Int
+    ): JavaPairDStream[K, V] = {
+    dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, numPartitions)
+  }
+
+  /**
+   * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to
+   * `DStream.reduceByKey()`, but applies it over a sliding window.
+   * @param reduceFunc associative reduce function
+   * @param windowDuration width of the window; must be a multiple of this DStream's
+   *                       batching interval
+   * @param slideDuration sliding interval of the window (i.e., the interval after which
+   *                      the new DStream will generate RDDs); must be a multiple of this
+   *                      DStream's batching interval
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+   */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      partitioner: Partitioner
+    ): JavaPairDStream[K, V] = {
+    dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner)
+  }
+
+  /**
+   * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+   * The reduced value over a new window is calculated using the old window's reduced value:
+   *  1. reduce the new values that entered the window (e.g., adding new counts)
+   *  2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+   * This is more efficient than reduceByKeyAndWindow without the "inverse reduce" function.
+   * However, it is applicable to only "invertible reduce functions".
+   * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+   * @param reduceFunc associative reduce function
+   * @param invReduceFunc inverse function
+   * @param windowDuration width of the window; must be a multiple of this DStream's
+   *                       batching interval
+   * @param slideDuration sliding interval of the window (i.e., the interval after which
+   *                      the new DStream will generate RDDs); must be a multiple of this
+   *                      DStream's batching interval
+   */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      invReduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration
+    ): JavaPairDStream[K, V] = {
+    dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration)
+  }
+
+  /**
+   * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+   * The reduced value over a new window is calculated using the old window's reduced value:
+   *  1. reduce the new values that entered the window (e.g., adding new counts)
+   *  2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+   * This is more efficient than reduceByKeyAndWindow without the "inverse reduce" function.
+   * However, it is applicable to only "invertible reduce functions".
+   * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+   * @param reduceFunc associative reduce function
+   * @param invReduceFunc inverse function
+   * @param windowDuration width of the window; must be a multiple of this DStream's
+   *                       batching interval
+   * @param slideDuration sliding interval of the window (i.e., the interval after which
+   *                      the new DStream will generate RDDs); must be a multiple of this
+   *                      DStream's batching interval
+   * @param numPartitions number of partitions of each RDD in the new DStream.
+   * @param filterFunc function to filter expired key-value pairs;
+   *                   only pairs that satisfy the function are retained;
+   *                   set this to null if you do not want to filter
+   */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      invReduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      numPartitions: Int,
+      filterFunc: JFunction[(K, V), java.lang.Boolean]
+    ): JavaPairDStream[K, V] = {
+    dstream.reduceByKeyAndWindow(
+        reduceFunc,
+        invReduceFunc,
+        windowDuration,
+        slideDuration,
+        numPartitions,
+        (p: (K, V)) => filterFunc(p).booleanValue()
+    )
+  }
+
+  /**
+   * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+   * The reduced value over a new window is calculated using the old window's reduced value:
+   *  1. reduce the new values that entered the window (e.g., adding new counts)
+   *  2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+   * This is more efficient than reduceByKeyAndWindow without the "inverse reduce" function.
+   * However, it is applicable to only "invertible reduce functions".
+   * @param reduceFunc associative reduce function
+   * @param invReduceFunc inverse function
+   * @param windowDuration width of the window; must be a multiple of this DStream's
+   *                       batching interval
+   * @param slideDuration sliding interval of the window (i.e., the interval after which
+   *                      the new DStream will generate RDDs); must be a multiple of this
+   *                      DStream's batching interval
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+   * @param filterFunc function to filter expired key-value pairs;
+   *                   only pairs that satisfy the function are retained;
+   *                   set this to null if you do not want to filter
+   */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      invReduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      partitioner: Partitioner,
+      filterFunc: JFunction[(K, V), java.lang.Boolean]
+    ): JavaPairDStream[K, V] = {
+    dstream.reduceByKeyAndWindow(
+        reduceFunc,
+        invReduceFunc,
+        windowDuration,
+        slideDuration,
+        partitioner,
+        (p: (K, V)) => filterFunc(p).booleanValue()
+    )
+  }
+
+  private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]):
+  (Seq[V], Option[S]) => Option[S] = {
+    val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => {
+      val list: JList[V] = values
+      val scalaState: Optional[S] = JavaUtils.optionToOptional(state)
+      val result: Optional[S] = in.apply(list, scalaState)
+      result.isPresent match {
+        case true => Some(result.get())
+        case _ => None
+      }
+    }
+    scalaFunc
+  }
+
+  /**
+   * Create a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of each key.
+   * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+   * @param updateFunc State update function. If this function returns None, then
+   *                   the corresponding state key-value pair will be eliminated.
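A running-count sketch with the Scala updateStateByKey that this wrapper converts to; it assumes `pairs: DStream[(String, Int)]` and that a checkpoint directory has been set on the context:

// Sum the new values for a key into the previously stored count (missing state starts at 0).
val updateRunningCount = (newValues: Seq[Int], state: Option[Int]) =>
  Some(newValues.sum + state.getOrElse(0))
val runningCounts = pairs.updateStateByKey[Int](updateRunningCount)
runningCounts.print()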
+ * @tparam S State type + */ + def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]]) + : JavaPairDStream[K, S] = { + implicit val cm: ClassManifest[S] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]] + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc)) + } + + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param numPartitions Number of partitions of each RDD in the new DStream. + * @tparam S State type + */ + def updateStateByKey[S: ClassManifest]( + updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], + numPartitions: Int) + : JavaPairDStream[K, S] = { + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions) + } + + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key. + * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @tparam S State type + */ + def updateStateByKey[S: ClassManifest]( + updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], + partitioner: Partitioner + ): JavaPairDStream[K, S] = { + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner) + } + + def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + dstream.mapValues(f) + } + + def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { + import scala.collection.JavaConverters._ + def fn = (x: V) => f.apply(x).asScala + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + dstream.flatMapValues(fn) + } + + /** + * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number + * of partitions. + */ + def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + } + + /** + * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. Partitioner is used to partition each generated RDD. 
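Illustrative join and cogroup calls on the Scala API backing this wrapper, with the pair-DStream implicit in scope as above; `clicks` and `impressions` are assumed `DStream[(String, Int)]` inputs keyed by the same id:

val joinedPerKey = clicks.join(impressions)      // DStream[(String, (Int, Int))]
val groupedPerKey = clicks.cogroup(impressions)  // DStream[(String, (Seq[Int], Seq[Int]))]
joinedPerKey.print()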
+ */ + def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner) + : JavaPairDStream[K, (JList[V], JList[W])] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.cogroup(other.dstream, partitioner) + .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + } + + /** + * Join `this` DStream with `other` DStream. HashPartitioner is used + * to partition each generated RDD into default number of partitions. + */ + def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.join(other.dstream) + } + + /** + * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * be generated by joining RDDs from `this` and other DStream. Uses the given + * Partitioner to partition each generated RDD. + */ + def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner) + : JavaPairDStream[K, (V, W)] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.join(other.dstream, partitioner) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { + dstream.saveAsHadoopFiles(prefix, suffix) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]]) { + dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf) { + dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { + dstream.saveAsNewAPIHadoopFiles(prefix, suffix) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ + def saveAsNewAPIHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { + dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) + } + + /** + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". 
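For illustration, writing each batch out using the prefix-TIME_IN_MS.suffix naming scheme described here; the HDFS path is a placeholder and `DStream.saveAsTextFiles` (defined in DStream.scala, not shown in this hunk) is assumed:

wordCounts
  .map { case (word, count) => word + "\t" + count }
  .saveAsTextFiles("hdfs:///tmp/wordcounts", "txt")   // one output directory per batch interval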
+ */ + def saveAsNewAPIHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = new Configuration) { + dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) + } + + override val classManifest: ClassManifest[(K, V)] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] +} + +object JavaPairDStream { + implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)]) + :JavaPairDStream[K, V] = + new JavaPairDStream[K, V](dstream) + + def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = { + implicit val cmk: ClassManifest[K] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val cmv: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + new JavaPairDStream[K, V](dstream.dstream) + } + + def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long]) + : JavaPairDStream[K, JLong] = { + StreamingContext.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala new file mode 100644 index 0000000000..f10beb1db3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -0,0 +1,614 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.streaming._ +import receivers.{ActorReceiver, ReceiverSupervisorStrategy} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import twitter4j.Status +import akka.actor.Props +import akka.actor.SupervisorStrategy +import akka.zeromq.Subscribe +import scala.collection.JavaConversions._ +import java.lang.{Long => JLong, Integer => JInt} +import java.io.InputStream +import java.util.{Map => JMap} +import twitter4j.auth.Authorization +import org.apache.spark.RDD + +/** + * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic + * information (such as, cluster URL and job name) to internally create a SparkContext, it provides + * methods used to create DStream from various input sources. 
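As a quick orientation before the constructors below: a JavaStreamingContext is typically created, wired to one or more input streams and at least one output operation, and then started. A minimal sketch in Scala; the master URL, application name, host and port are placeholders, and print() is assumed to be the output operation exposed by JavaDStreamLike.

    import org.apache.spark.streaming.Seconds
    import org.apache.spark.streaming.api.java.JavaStreamingContext

    // One-second batches; socketTextStream and print() register the work to run per batch.
    val jssc = new JavaStreamingContext("local[2]", "JavaNetworkWordCount", Seconds(1))
    val lines = jssc.socketTextStream("localhost", 9999)
    lines.print()
    jssc.start()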
+ */ +class JavaStreamingContext(val ssc: StreamingContext) { + + // TODOs: + // - Test to/from Hadoop functions + // - Support creating and registering InputStreams + + + /** + * Creates a StreamingContext. + * @param master Name of the Spark Master + * @param appName Name to be used when registering with the scheduler + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(master: String, appName: String, batchDuration: Duration) = + this(new StreamingContext(master, appName, batchDuration, null, Nil, Map())) + + /** + * Creates a StreamingContext. + * @param master Name of the Spark Master + * @param appName Name to be used when registering with the scheduler + * @param batchDuration The time interval at which streaming data will be divided into batches + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jarFile JAR file containing job code, to ship to cluster. This can be a path on the local + * file system or an HDFS, HTTP, HTTPS, or FTP URL. + */ + def this( + master: String, + appName: String, + batchDuration: Duration, + sparkHome: String, + jarFile: String) = + this(new StreamingContext(master, appName, batchDuration, sparkHome, Seq(jarFile), Map())) + + /** + * Creates a StreamingContext. + * @param master Name of the Spark Master + * @param appName Name to be used when registering with the scheduler + * @param batchDuration The time interval at which streaming data will be divided into batches + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + */ + def this( + master: String, + appName: String, + batchDuration: Duration, + sparkHome: String, + jars: Array[String]) = + this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, Map())) + + /** + * Creates a StreamingContext. + * @param master Name of the Spark Master + * @param appName Name to be used when registering with the scheduler + * @param batchDuration The time interval at which streaming data will be divided into batches + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + * @param environment Environment variables to set on worker nodes + */ + def this( + master: String, + appName: String, + batchDuration: Duration, + sparkHome: String, + jars: Array[String], + environment: JMap[String, String]) = + this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) + + /** + * Creates a StreamingContext using an existing SparkContext. + * @param sparkContext The underlying JavaSparkContext to use + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(sparkContext: JavaSparkContext, batchDuration: Duration) = + this(new StreamingContext(sparkContext.sc, batchDuration)) + + /** + * Re-creates a StreamingContext from a checkpoint file. + * @param path Path either to the directory that was specified as the checkpoint directory, or + * to the checkpoint file 'graph' or 'graph.bk'. + */ + def this(path: String) = this (new StreamingContext(path)) + + /** The underlying SparkContext */ + val sc: JavaSparkContext = new JavaSparkContext(ssc.sc) + + /** + * Create an input stream that pulls messages form a Kafka Broker. 
+ * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..). + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + */ + def kafkaStream( + zkQuorum: String, + groupId: String, + topics: JMap[String, JInt]) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + StorageLevel.MEMORY_ONLY_SER_2) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..). + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param storageLevel RDD storage level. Defaults to memory-only + * + */ + def kafkaStream( + zkQuorum: String, + groupId: String, + topics: JMap[String, JInt], + storageLevel: StorageLevel) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param typeClass Type of RDD + * @param decoderClass Type of kafka decoder + * @param kafkaParams Map of kafka configuration parameters. + * See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param storageLevel RDD storage level. Defaults to memory-only + */ + def kafkaStream[T, D <: kafka.serializer.Decoder[_]]( + typeClass: Class[T], + decoderClass: Class[D], + kafkaParams: JMap[String, String], + topics: JMap[String, JInt], + storageLevel: StorageLevel) + : JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] + ssc.kafkaStream[T, D]( + kafkaParams.toMap, + Map(topics.mapValues(_.intValue()).toSeq: _*), + storageLevel) + } + + /** + * Create an input stream from network source hostname:port. Data is received using + * a TCP socket and the received bytes are interpreted as UTF8 encoded \n delimited + * lines. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param storageLevel Storage level to use for storing the received objects + * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + */ + def socketTextStream(hostname: String, port: Int, storageLevel: StorageLevel) + : JavaDStream[String] = { + ssc.socketTextStream(hostname, port, storageLevel) + } + + /** + * Create an input stream from network source hostname:port. Data is received using + * a TCP socket and the received bytes are interpreted as UTF8 encoded \n delimited + * lines. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + */ + def socketTextStream(hostname: String, port: Int): JavaDStream[String] = { + ssc.socketTextStream(hostname, port) + } + + /** + * Create an input stream from network source hostname:port.
Data is received using + * a TCP socket and the received bytes are interpreted as objects using the given + * converter. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param converter Function to convert the byte stream to objects + * @param storageLevel Storage level to use for storing the received objects + * @tparam T Type of the objects received (after converting bytes to objects) + */ + def socketStream[T]( + hostname: String, + port: Int, + converter: JFunction[InputStream, java.lang.Iterable[T]], + storageLevel: StorageLevel) + : JavaDStream[T] = { + def fn = (x: InputStream) => converter.apply(x).toIterator + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.socketStream(hostname, port, fn, storageLevel) + } + + /** + * Creates an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as text files (using key as LongWritable, value + * as Text and input format as TextInputFormat). File names starting with . are ignored. + * @param directory HDFS directory to monitor for new files + */ + def textFileStream(directory: String): JavaDStream[String] = { + ssc.textFileStream(directory) + } + + /** + * Create an input stream from network source hostname:port, where data is received + * as serialized blocks (serialized using Spark's serializer) that can be directly + * pushed into the block manager without deserializing them. This is the most efficient + * way to receive data. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param storageLevel Storage level to use for storing the received objects + * @tparam T Type of the objects in the received blocks + */ + def rawSocketStream[T]( + hostname: String, + port: Int, + storageLevel: StorageLevel): JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port, storageLevel)) + } + + /** + * Create an input stream from network source hostname:port, where data is received + * as serialized blocks (serialized using Spark's serializer) that can be directly + * pushed into the block manager without deserializing them. This is the most efficient + * way to receive data. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @tparam T Type of the objects in the received blocks + */ + def rawSocketStream[T](hostname: String, port: Int): JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port)) + } + + /** + * Creates an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * File names starting with . are ignored.
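The file-based streams here (textFileStream and the typed fileStream that follows) can be exercised with a sketch like the one below against the underlying Scala StreamingContext; the directory is a placeholder, `ssc` is assumed to be a StreamingContext as in the earlier word-count sketch, and new files are typically moved into the watched directory atomically (for example via rename) so they are picked up cleanly.

    // Watch a directory for new text files; names starting with "." are ignored.
    val logs = ssc.textFileStream("hdfs://namenode:8020/logs/incoming")
    val errors = logs.filter(_.contains("ERROR"))
    errors.print()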
+ * @param directory HDFS directory to monitor for new file + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]](directory: String): JavaPairDStream[K, V] = { + implicit val cmk: ClassManifest[K] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val cmv: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + implicit val cmf: ClassManifest[F] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[F]] + ssc.fileStream[K, V, F](directory); + } + + /** + * Creates a input stream from a Flume source. + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + * @param storageLevel Storage level to use for storing the received objects + */ + def flumeStream(hostname: String, port: Int, storageLevel: StorageLevel): + JavaDStream[SparkFlumeEvent] = { + ssc.flumeStream(hostname, port, storageLevel) + } + + + /** + * Creates a input stream from a Flume source. + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + */ + def flumeStream(hostname: String, port: Int): JavaDStream[SparkFlumeEvent] = { + ssc.flumeStream(hostname, port) + } + + /** + * Create a input stream that returns tweets received from Twitter. + * @param twitterAuth Twitter4J Authorization object + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + twitterAuth: Authorization, + filters: Array[String], + storageLevel: StorageLevel + ): JavaDStream[Status] = { + ssc.twitterStream(Some(twitterAuth), filters, storageLevel) + } + + /** + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + filters: Array[String], + storageLevel: StorageLevel + ): JavaDStream[Status] = { + ssc.twitterStream(None, filters, storageLevel) + } + + /** + * Create a input stream that returns tweets received from Twitter. + * @param twitterAuth Twitter4J Authorization + * @param filters Set of filter strings to get only those tweets that match them + */ + def twitterStream( + twitterAuth: Authorization, + filters: Array[String] + ): JavaDStream[Status] = { + ssc.twitterStream(Some(twitterAuth), filters) + } + + /** + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. + * @param filters Set of filter strings to get only those tweets that match them + */ + def twitterStream( + filters: Array[String] + ): JavaDStream[Status] = { + ssc.twitterStream(None, filters) + } + + /** + * Create a input stream that returns tweets received from Twitter. 
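The flumeStream variants above act as an Avro source: a Flume agent with an Avro sink pointed at the given host and port pushes events into the receiver. A minimal sketch using the underlying Scala StreamingContext; the host and port are placeholders and `ssc` is assumed as before.

    import org.apache.spark.storage.StorageLevel

    // Receive SparkFlumeEvents and decode each event body as a UTF-8 string.
    val events = ssc.flumeStream("worker-1.example.com", 41414, StorageLevel.MEMORY_AND_DISK_SER_2)
    val bodies = events.map(e => new String(e.event.getBody.array(), "UTF-8"))
    bodies.print()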
+ * @param twitterAuth Twitter4J Authorization + */ + def twitterStream( + twitterAuth: Authorization + ): JavaDStream[Status] = { + ssc.twitterStream(Some(twitterAuth)) + } + + /** + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. + */ + def twitterStream(): JavaDStream[Status] = { + ssc.twitterStream() + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * @param storageLevel Storage level to use for storing the received objects + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T]( + props: Props, + name: String, + storageLevel: StorageLevel, + supervisorStrategy: SupervisorStrategy + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.actorStream[T](props, name, storageLevel, supervisorStrategy) + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * @param storageLevel Storage level to use for storing the received objects + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T]( + props: Props, + name: String, + storageLevel: StorageLevel + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.actorStream[T](props, name, storageLevel) + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T]( + props: Props, + name: String + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.actorStream[T](props, name) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. 
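Because the bytesToObjects contract above is easy to misread, here is a sketch of the Scala form these overloads delegate to: the converter receives the frames of one published message as a Seq[Seq[Byte]] and returns an iterator of decoded records. The URL and topic are placeholders, `ssc` is assumed as before, and Subscribe("events") assumes akka.zeromq.Subscribe accepts a topic string.

    import akka.zeromq.Subscribe

    // Decode each frame's bytes as a UTF-8 string record.
    def framesToStrings(frames: Seq[Seq[Byte]]): Iterator[String] =
      frames.map(frame => new String(frame.toArray, "UTF-8")).iterator

    val records = ssc.zeroMQStream[String]("tcp://127.0.0.1:1234", Subscribe("events"), framesToStrings _)
    records.print()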
+ * @param storageLevel Storage level to use for storing the received objects + */ + def zeroMQStream[T]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T], + storageLevel: StorageLevel, + supervisorStrategy: SupervisorStrategy + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.zeroMQStream[T](publisherUrl, subscribe, bytesToObjects, storageLevel, supervisorStrategy) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + * @param storageLevel RDD storage level. Defaults to memory-only. + */ + def zeroMQStream[T]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], + storageLevel: StorageLevel + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + ssc.zeroMQStream[T](publisherUrl, subscribe, fn, storageLevel) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + */ + def zeroMQStream[T]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]] + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + ssc.zeroMQStream[T](publisherUrl, subscribe, fn) + } + + /** + * Registers an output stream that will be computed every interval + */ + def registerOutputStream(outputStream: JavaDStreamLike[_, _, _]) { + ssc.registerOutputStream(outputStream.dstream) + } + + /** + * Creates a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: changes to the queue after the stream is created will not be recognized. + * @param queue Queue of RDDs + * @tparam T Type of objects in the RDD + */ + def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val sQueue = new scala.collection.mutable.Queue[RDD[T]] + sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + ssc.queueStream(sQueue) + } + + /** + * Creates a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: changes to the queue after the stream is created will not be recognized. 
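queueStream is mostly a testing hook; the Scala form that these overloads delegate to can be driven as in the sketch below, with `ssc` assumed as before and the numbers purely illustrative.

    import scala.collection.mutable.Queue
    import org.apache.spark.RDD

    // Pre-load two RDDs; with oneAtATime = true each batch consumes a single RDD.
    val rddQueue = new Queue[RDD[Int]]
    rddQueue += ssc.sparkContext.makeRDD(1 to 100, 2)
    rddQueue += ssc.sparkContext.makeRDD(101 to 200, 2)

    val queued = ssc.queueStream(rddQueue, true)
    queued.print()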
+ * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @tparam T Type of objects in the RDD + */ + def queueStream[T](queue: java.util.Queue[JavaRDD[T]], oneAtATime: Boolean): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val sQueue = new scala.collection.mutable.Queue[RDD[T]] + sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + ssc.queueStream(sQueue, oneAtATime) + } + + /** + * Creates an input stream from a queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: changes to the queue after the stream is created will not be recognized. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @param defaultRDD Default RDD returned by the DStream when the queue is empty + * @tparam T Type of objects in the RDD + */ + def queueStream[T]( + queue: java.util.Queue[JavaRDD[T]], + oneAtATime: Boolean, + defaultRDD: JavaRDD[T]): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val sQueue = new scala.collection.mutable.Queue[RDD[T]] + sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) + } + + /** + * Sets the context to periodically checkpoint the DStream operations for master + * fault-tolerance. The graph will be checkpointed every batch interval. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + */ + def checkpoint(directory: String) { + ssc.checkpoint(directory) + } + + /** + * Sets each DStream in this context to remember the RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration and release them for garbage + * collection. This method allows the developer to specify how long to remember the RDDs ( + * if the developer wishes to query old data outside the DStream computation). + * @param duration Minimum duration that each DStream should remember its RDDs + */ + def remember(duration: Duration) { + ssc.remember(duration) + } + + /** + * Starts the execution of the streams. + */ + def start() = ssc.start() + + /** + * Stops the execution of the streams. + */ + def stop() = ssc.stop() + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala new file mode 100644 index 0000000000..4a9d82211f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.{RDD, Partitioner} +import org.apache.spark.rdd.CoGroupedRDD +import org.apache.spark.streaming.{Time, DStream, Duration} + +private[streaming] +class CoGroupedDStream[K : ClassManifest]( + parents: Seq[DStream[(K, _)]], + partitioner: Partitioner + ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideDuration).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideDuration: Duration = parents.head.slideDuration + + override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { + val part = partitioner + val rdds = parents.flatMap(_.getOrCompute(validTime)) + if (rdds.size > 0) { + val q = new CoGroupedRDD[K](rdds, part) + Some(q) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala new file mode 100644 index 0000000000..35cc4cb396 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.streaming.{Time, StreamingContext} + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) + extends InputDStream[T](ssc_) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala new file mode 100644 index 0000000000..1c265ed972 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.rdd.UnionRDD +import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time} + +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + +import scala.collection.mutable.{HashSet, HashMap} +import java.io.{ObjectInputStream, IOException} + +private[streaming] +class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + @transient ssc_ : StreamingContext, + directory: String, + filter: Path => Boolean = FileInputDStream.defaultFilter, + newFilesOnly: Boolean = true) + extends InputDStream[(K, V)](ssc_) { + + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + + // Latest file mod time seen till any point of time + private val lastModTimeFiles = new HashSet[String]() + private var lastModTime = 0L + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + @transient private[streaming] var files = new HashMap[Time, Array[String]] + + override def start() { + if (newFilesOnly) { + lastModTime = graph.zeroTime.milliseconds + } else { + lastModTime = 0 + } + logDebug("LastModTime initialized to " + lastModTime + ", new files only = " + newFilesOnly) + } + + override def stop() { } + + /** + * Finds the files that were modified since the last time this method was called and makes + * a union RDD out of them. Note that this maintains the list of files that were processed + * in the latest modification time in the previous call to this method. This is because the + * modification time returned by the FileStatus API seems to return times only at the + * granularity of seconds. And new files may have the same modification time as the + * latest modification time in the previous call to this method yet was not reported in + * the previous call. 
+ */ + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + assert(validTime.milliseconds >= lastModTime, "Trying to get new files for really old time [" + validTime + " < " + lastModTime) + + // Create the filter for selecting new files + val newFilter = new PathFilter() { + // Latest file mod time seen in this round of fetching files and its corresponding files + var latestModTime = 0L + val latestModTimeFiles = new HashSet[String]() + + def accept(path: Path): Boolean = { + if (!filter(path)) { // Reject file if it does not satisfy filter + logDebug("Rejected by filter " + path) + return false + } else { // Accept file only if + val modTime = fs.getFileStatus(path).getModificationTime() + logDebug("Mod time for " + path + " is " + modTime) + if (modTime < lastModTime) { + logDebug("Mod time less than last mod time") + return false // If the file was created before the last time it was called + } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { + logDebug("Mod time equal to last mod time, but file considered already") + return false // If the file was created exactly as lastModTime but not reported yet + } else if (modTime > validTime.milliseconds) { + logDebug("Mod time more than valid time") + return false // If the file was created after the time this function call requires + } + if (modTime > latestModTime) { + latestModTime = modTime + latestModTimeFiles.clear() + logDebug("Latest mod time updated to " + latestModTime) + } + latestModTimeFiles += path.toString + logDebug("Accepted " + path) + return true + } + } + } + logDebug("Finding new files at time " + validTime + " for last mod time = " + lastModTime) + val newFiles = fs.listStatus(path, newFilter).map(_.getPath.toString) + logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) + if (newFiles.length > 0) { + // Update the modification time and the files processed for that modification time + if (lastModTime != newFilter.latestModTime) { + lastModTime = newFilter.latestModTime + lastModTimeFiles.clear() + } + lastModTimeFiles ++= newFilter.latestModTimeFiles + logDebug("Last mod time updated to " + lastModTime) + } + files += ((validTime, newFiles)) + Some(filesToRDD(newFiles)) + } + + /** Clear the old time-to-files mappings along with old RDDs */ + protected[streaming] override def clearOldMetadata(time: Time) { + super.clearOldMetadata(time) + val oldFiles = files.filter(_._1 <= (time - rememberDuration)) + files --= oldFiles.keys + logInfo("Cleared " + oldFiles.size + " old files that were older than " + + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) + logDebug("Cleared files are:\n" + + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + } + + /** Generate one RDD from an array of files */ + protected[streaming] def filesToRDD(files: Seq[String]): RDD[(K, V)] = { + new UnionRDD( + context.sparkContext, + files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) + ) + } + + private def path: Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + private def fs: FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + logDebug(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[(K,V)]] () + files = new HashMap[Time, Array[String]] + } + + /** + * A custom version of the DStreamCheckpointData that stores names of + * 
Hadoop files as checkpoint data. + */ + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + + def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] + + override def update() { + hadoopFiles.clear() + hadoopFiles ++= files + } + + override def cleanup() { } + + override def restore() { + hadoopFiles.foreach { + case (t, f) => { + // Restore the metadata in both files and generatedRDDs + logInfo("Restoring files for time " + t + " - " + + f.mkString("[", ", ", "]") ) + files += ((t, f)) + generatedRDDs += ((t, filesToRDD(f))) + } + } + } + + override def toString() = { + "[\n" + hadoopFiles.size + " file sets\n" + + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" + } + } +} + +private[streaming] +object FileInputDStream { + def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala new file mode 100644 index 0000000000..3166c68760 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD + +private[streaming] +class FilteredDStream[T: ClassManifest]( + parent: DStream[T], + filterFunc: T => Boolean + ) extends DStream[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala new file mode 100644 index 0000000000..21950ad6ac --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD +import org.apache.spark.SparkContext._ + +private[streaming] +class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + flatMapValueFunc: V => TraversableOnce[U] + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala new file mode 100644 index 0000000000..8377cfe60c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD + +private[streaming] +class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + flatMapFunc: T => Traversable[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala new file mode 100644 index 0000000000..3fb443143c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.StreamingContext + +import org.apache.spark.Utils +import org.apache.spark.storage.StorageLevel + +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.avro.ipc.NettyServer + +import scala.collection.JavaConversions._ + +import java.net.InetSocketAddress +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.nio.ByteBuffer + +private[streaming] +class FlumeInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel +) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { + + override def getReceiver(): NetworkReceiver[SparkFlumeEvent] = { + new FlumeReceiver(host, port, storageLevel) + } +} + +/** + * A wrapper class for AvroFlumeEvent's with a custom serialization format. + * + * This is necessary because AvroFlumeEvent uses inner data structures + * which are not serializable. + */ +class SparkFlumeEvent() extends Externalizable { + var event : AvroFlumeEvent = new AvroFlumeEvent() + + /* De-serialize from bytes. */ + def readExternal(in: ObjectInput) { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.read(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.read(keyBuff) + val key : String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.read(valBuff) + val value : String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + + event.setBody(ByteBuffer.wrap(bodyBuff)) + event.setHeaders(headers) + } + + /* Serialize to bytes. */ + def writeExternal(out: ObjectOutput) { + val body = event.getBody.array() + out.writeInt(body.length) + out.write(body) + + val numHeaders = event.getHeaders.size() + out.writeInt(numHeaders) + for ((k, v) <- event.getHeaders) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} + +private[streaming] object SparkFlumeEvent { + def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + val event = new SparkFlumeEvent + event.event = in + event + } +} + +/** A simple server that implements Flume's Avro protocol. 
*/ +private[streaming] +class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { + override def append(event : AvroFlumeEvent) : Status = { + receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event) + Status.OK + } + + override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + events.foreach (event => + receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event)) + Status.OK + } +} + +/** A NetworkReceiver which listens for events using the + * Flume Avro interface.*/ +private[streaming] +class FlumeReceiver( + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent] { + + lazy val blockGenerator = new BlockGenerator(storageLevel) + + protected override def onStart() { + val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)); + val server = new NettyServer(responder, new InetSocketAddress(host, port)); + blockGenerator.start() + server.start() + logInfo("Flume receiver started") + } + + protected override def onStop() { + blockGenerator.stop() + logInfo("Flume receiver stopped") + } + + override def getLocationPreference = Some(host) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala new file mode 100644 index 0000000000..c1f95650c8 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.streaming.{Duration, DStream, Job, Time} + +private[streaming] +class ForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: (RDD[T], Time) => Unit + ) extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala new file mode 100644 index 0000000000..1e4c7e7fde --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD + +private[streaming] +class GlommedDStream[T: ClassManifest](parent: DStream[T]) + extends DStream[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala new file mode 100644 index 0000000000..674b27118c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Time, Duration, StreamingContext, DStream} + +/** + * This is the abstract base class for all input streams. This class provides two methods, + * start() and stop(), which are called by the scheduler to start and stop receiving data. + * Input streams that can generate RDDs from new data just by running a service on + * the driver node (that is, without running a receiver on worker nodes) can be + * implemented by directly subclassing this InputDStream. For example, + * FileInputDStream, a subclass of InputDStream, monitors an HDFS directory for + * new files and generates RDDs on the new files. For implementing input streams + * that require running a receiver on the worker nodes, use NetworkInputDStream + * as the parent class. + */ +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { + + var lastValidTime: Time = null + + /** + * Checks whether the 'time' is valid w.r.t. slideDuration for generating RDDs. + * Additionally, it ensures that valid times are in strictly increasing order.
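To illustrate the driver-side subclassing path described in the class comment above, here is a minimal sketch in the spirit of ConstantInputDStream; the class name is made up and it is shown as if it lived in org.apache.spark.streaming.dstream.

    package org.apache.spark.streaming.dstream

    import org.apache.spark.RDD
    import org.apache.spark.streaming.{StreamingContext, Time}

    /** Emits one small RDD holding the batch time on every interval; no receiver involved. */
    class BatchTimeInputDStream(ssc_ : StreamingContext)
      extends InputDStream[Long](ssc_) {

      override def start() {}   // nothing to set up on the driver
      override def stop() {}

      override def compute(validTime: Time): Option[RDD[Long]] = {
        Some(context.sparkContext.makeRDD(Seq(validTime.milliseconds), 1))
      }
    }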
+ * This ensures that InputDStream.compute() is called strictly on increasing + * times. + */ + override protected def isTimeValid(time: Time): Boolean = { + if (!super.isTimeValid(time)) { + false // Time not valid + } else { + // Time is valid, but check if it is more than lastValidTime + if (lastValidTime != null && time < lastValidTime) { + logWarning("isTimeValid called with " + time + " whereas the last valid time is " + lastValidTime) + } + lastValidTime = time + true + } + } + + override def dependencies = List() + + override def slideDuration: Duration = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") + ssc.graph.batchDuration + } + + /** Method called to start receiving data. Subclasses must implement this method. */ + def start() + + /** Method called to stop receiving data. Subclasses must implement this method. */ + def stop() +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala new file mode 100644 index 0000000000..51e913675d --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Time, DStreamCheckpointData, StreamingContext} + +import java.util.Properties +import java.util.concurrent.Executors + +import kafka.consumer._ +import kafka.message.{Message, MessageSet, MessageAndMetadata} +import kafka.serializer.Decoder +import kafka.utils.{Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ +import kafka.utils.ZKStringSerializer +import org.I0Itec.zkclient._ + +import scala.collection.Map +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + + +/** + * Input stream that pulls messages from a Kafka Broker. + * + * @param kafkaParams Map of kafka configuration parameters. See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param storageLevel RDD storage level.
+ */ +private[streaming] +class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest]( + @transient ssc_ : StreamingContext, + kafkaParams: Map[String, String], + topics: Map[String, Int], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + + def getReceiver(): NetworkReceiver[T] = { + new KafkaReceiver[T, D](kafkaParams, topics, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } +} + +private[streaming] +class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + storageLevel: StorageLevel + ) extends NetworkReceiver[Any] { + + // Handles pushing data into the BlockManager + lazy protected val blockGenerator = new BlockGenerator(storageLevel) + // Connection to Kafka + var consumerConnector : ConsumerConnector = null + + def onStop() { + blockGenerator.stop() + } + + def onStart() { + + blockGenerator.start() + + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) + + logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) + + // Kafka connection properties + val props = new Properties() + kafkaParams.foreach(param => props.put(param._1, param._2)) + + // Create the connection to the cluster + logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) + val consumerConfig = new ConsumerConfig(props) + consumerConnector = Consumer.create(consumerConfig) + logInfo("Connected to " + kafkaParams("zk.connect")) + + // When autooffset.reset is defined, it is our responsibility to try and whack the + // consumer group zk node. + if (kafkaParams.contains("autooffset.reset")) { + tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) + } + + // Create Threads for each Topic/Message Stream we are listening + val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]] + val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder) + + // Start the messages handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + } + + // Handles Kafka Messages + private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable { + def run() { + logInfo("Starting MessageHandler.") + for (msgAndMetadata <- stream) { + blockGenerator += msgAndMetadata.message + } + } + } + + // It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because + // Kafka 0.7.2 only honors this param when the group is not in zookeeper. + // + // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' + // ConsoleConsumer. 
See code related to 'autooffset.reset' when it is set to 'smallest'/'largest': + // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala + private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { + try { + val dir = "/consumers/" + groupId + logInfo("Cleaning up temporary zookeeper data under " + dir + ".") + val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) + zk.deleteRecursive(dir) + zk.close() + } catch { + case _ => // swallow + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala new file mode 100644 index 0000000000..1d79d707bb --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD + +private[streaming] +class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala new file mode 100644 index 0000000000..312e0c0567 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD +import org.apache.spark.SparkContext._ + +private[streaming] +class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + mapValueFunc: V => U + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala new file mode 100644 index 0000000000..af688dde5f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD + +private[streaming] +class MappedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + mapFunc: T => U + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala new file mode 100644 index 0000000000..3d68da36a2 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver} + +import org.apache.spark.{Logging, SparkEnv, RDD} +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +import java.nio.ByteBuffer + +import akka.actor.{Props, Actor} +import akka.pattern.ask +import akka.dispatch.Await +import akka.util.duration._ +import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} +import java.util.concurrent.ArrayBlockingQueue + +/** + * Abstract class for defining any InputDStream that has to start a receiver on worker + * nodes to receive external data. Specific implementations of NetworkInputDStream must + * define the getReceiver() function that gets the receiver object of type + * [[org.apache.spark.streaming.dstream.NetworkReceiver]] that will be sent to the workers to receive + * data. + * @param ssc_ Streaming context that will execute this input stream + * @tparam T Class type of the object of this stream + */ +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) + extends InputDStream[T](ssc_) { + + // This is an unique identifier that is used to match the network receiver with the + // corresponding network input stream. + val id = ssc.getNewNetworkStreamId() + + /** + * Gets the receiver object that will be sent to the worker nodes + * to receive data. This method needs to defined by any specific implementation + * of a NetworkInputDStream. + */ + def getReceiver(): NetworkReceiver[T] + + // Nothing to start or stop as both taken care of by the NetworkInputTracker. + def start() {} + + def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // master failure + if (validTime >= graph.startTime) { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } else { + Some(new BlockRDD[T](ssc.sc, Array[String]())) + } + } +} + + +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage + +/** + * Abstract class of a receiver that can be run on worker nodes to receive external data. See + * [[org.apache.spark.streaming.dstream.NetworkInputDStream]] for an explanation. + */ +abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Logging { + + initLogging() + + lazy protected val env = SparkEnv.get + + lazy protected val actor = env.actorSystem.actorOf( + Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + + lazy protected val receivingThread = Thread.currentThread() + + protected var streamId: Int = -1 + + /** + * This method will be called to start receiving data. All your receiver + * starting code should be implemented by defining this function. + */ + protected def onStart() + + /** This method will be called to stop receiving data. */ + protected def onStop() + + /** Conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None + + /** + * Starts the receiver. 
First it accesses all the lazy members to + * materialize them. Then it calls the user-defined onStart() method to start + * other threads, etc. that are required to receive the data. + */ + def start() { + try { + // Access the lazy vals to materialize them + env + actor + receivingThread + + // Call user-defined onStart() + onStart() + } catch { + case ie: InterruptedException => + logInfo("Receiving thread interrupted") + //println("Receiving thread interrupted") + case e: Exception => + stopOnError(e) + } + } + + /** + * Stops the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). Then it calls the user-defined + * onStop() method to stop other threads and/or do cleanup. + */ + def stop() { + receivingThread.interrupt() + onStop() + //TODO: terminate the actor + } + + /** + * Stops the receiver and reports the exception to the tracker. + * This should be called whenever an exception is to be handled on any thread + * of the receiver. + */ + protected def stopOnError(e: Exception) { + logError("Error receiving data", e) + stop() + actor ! ReportError(e.toString) + } + + + /** + * Pushes a block (as an ArrayBuffer filled with data) into the block manager. + */ + def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { + env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) + actor ! ReportBlock(blockId, metadata) + } + + /** + * Pushes a block (as bytes) into the block manager. + */ + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { + env.blockManager.putBytes(blockId, bytes, level) + actor ! ReportBlock(blockId, metadata) + } + + /** A helper actor that communicates with the NetworkInputTracker */ + private class NetworkReceiverActor extends Actor { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt + val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + val tracker = env.actorSystem.actorFor(url) + val timeout = 5.seconds + + override def preStart() { + val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive() = { + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) + case ReportError(msg) => + tracker ! DeregisterReceiver(streamId, msg) + case StopReceiver(msg) => + stop() + tracker ! DeregisterReceiver(streamId, msg) + } + } + + protected[streaming] def setStreamId(id: Int) { + streamId = id + } + + /** + * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts them into + * appropriately named blocks at regular intervals. This class starts two threads, + * one to periodically start a new batch and prepare the previous batch as a block, + * the other to push the blocks into the block manager.
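+ *
+ * A minimal usage sketch inside a receiver (it mirrors how SocketReceiver and
+ * KafkaReceiver in this patch use the class; the field name and receivedObject are
+ * placeholders). The batching interval is read from the spark.streaming.blockInterval
+ * system property (milliseconds, 200 by default).
+ * {{{
+ * lazy protected val blockGenerator = new BlockGenerator(storageLevel)
+ *
+ * protected def onStart() {
+ *   blockGenerator.start()  // starts the block interval timer and the pushing thread
+ *   // ... then, for every object received from the source:
+ *   blockGenerator += receivedObject
+ * }
+ *
+ * protected def onStop() {
+ *   blockGenerator.stop()
+ * }
+ * }}}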
+ */ + class BlockGenerator(storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + private def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval) + val newBlock = new Block(blockId, newBlockBuffer) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + NetworkReceiver.this.stop() + } + } + + private def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + NetworkReceiver.this.stop() + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala new file mode 100644 index 0000000000..15782f5c11 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.StreamingContext + +private[streaming] +class PluggableInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + receiver: NetworkReceiver[T]) extends NetworkInputDStream[T](ssc_) { + + def getReceiver(): NetworkReceiver[T] = { + receiver + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala new file mode 100644 index 0000000000..b43ecaeebe --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.rdd.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer +import org.apache.spark.streaming.{Time, StreamingContext} + +private[streaming] +class QueueInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputDStream[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + if (oneAtATime) { + Some(buffer.head) + } else { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala new file mode 100644 index 0000000000..c91f12ecd7 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue + + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +private[streaming] +class RawInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + def getReceiver(): NetworkReceiver[T] = { + new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + } +} + +private[streaming] +class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any] { + + var blockPushingThread: Thread = null + + override def getLocationPreference = None + + def onStart() { + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + blockPushingThread = new Thread { + setDaemon(true) + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + pushBlock(blockId, buffer, null, storageLevel) + } + } + } + blockPushingThread.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + def onStop() { + if (blockPushingThread != null) blockPushingThread.interrupt() + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala new file mode 100644 index 0000000000..b6c672f899 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.StreamingContext._ + +import org.apache.spark.RDD +import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD} +import org.apache.spark.Partitioner +import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import org.apache.spark.streaming.{Duration, Interval, Time, DStream} + +private[streaming] +class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( + parent: DStream[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + filterFunc: Option[((K, V)) => Boolean], + _windowDuration: Duration, + _slideDuration: Duration, + partitioner: Partitioner + ) extends DStream[(K,V)](parent.ssc) { + + assert(_windowDuration.isMultipleOf(parent.slideDuration), + "The window duration of ReducedWindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" + ) + + assert(_slideDuration.isMultipleOf(parent.slideDuration), + "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" + ) + + // Reduce each batch of data using reduceByKey which will be further reduced by window + // by ReducedWindowedDStream + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + + // Persist RDDs to memory by default as these RDDs are going to be reused. 
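+ // Illustrative note (not part of this class, and wordCounts is a placeholder name):
+ // a ReducedWindowedDStream is what PairDStreamFunctions.reduceByKeyAndWindow constructs
+ // when an inverse reduce function is supplied, e.g. a 30 second running count sliding
+ // every 10 seconds over a DStream of (word, 1) pairs:
+ //
+ //   wordCounts.reduceByKeyAndWindow(_ + _, _ - _, Seconds(30), Seconds(10))
+ //
+ // The inverse function lets each window be computed incrementally from the previous one
+ // ("old" values are subtracted, "new" values are added), which is the cogroup-and-merge
+ // logic implemented in compute() below.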
+ super.persist(StorageLevel.MEMORY_ONLY_SER) + reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowDuration: Duration = _windowDuration + + override def dependencies = List(reducedStream) + + override def slideDuration: Duration = _slideDuration + + override val mustCheckpoint = true + + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Duration): DStream[(K, V)] = { + super.checkpoint(interval) + //reducedStream.checkpoint(interval) + this + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc + + val currentTime = validTime + val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration, currentTime) + val previousWindow = currentWindow - slideDuration + + logDebug("Window time = " + windowDuration) + logDebug("Slide time = " + slideDuration) + logDebug("ZeroTime = " + zeroTime) + logDebug("Current window = " + currentWindow) + logDebug("Previous window = " + previousWindow) + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + + // Get the RDDs of the reduced values in "old time steps" + val oldRDDs = + reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) + logDebug("# old RDDs = " + oldRDDs.size) + + // Get the RDDs of the reduced values in "new time steps" + val newRDDs = + reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) + logDebug("# new RDDs = " + newRDDs.size) + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + + // Make the list of RDDs that needs to cogrouped together for reducing their reduced values + val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs + + // Cogroup the reduced RDDs and merge the reduced values + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]], partitioner) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size + + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") + } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = + (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found. 
" + + "Are you sure your key class hashes consistently?") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } + + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + + if (filterFunc.isDefined) { + Some(mergedValuesRDD.filter(filterFunc.get)) + } else { + Some(mergedValuesRDD) + } + } +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala new file mode 100644 index 0000000000..3a0bd2acd7 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.{RDD, Partitioner} +import org.apache.spark.SparkContext._ +import org.apache.spark.streaming.{Duration, DStream, Time} + +private[streaming] +class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: DStream[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner + ) extends DStream [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala new file mode 100644 index 0000000000..e2539c7396 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.NextIterator + +import java.io._ +import java.net.Socket + +private[streaming] +class SocketInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def getReceiver(): NetworkReceiver[T] = { + new SocketReceiver(host, port, bytesToObjects, storageLevel) + } +} + +private[streaming] +class SocketReceiver[T: ClassManifest]( + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T] { + + lazy protected val blockGenerator = new BlockGenerator(storageLevel) + + override def getLocationPreference = None + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + blockGenerator.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + blockGenerator += obj + } + } + + protected def onStop() { + blockGenerator.stop() + } + +} + +private[streaming] +object SocketReceiver { + + /** + * This methods translates the data from an inputstream (say, from a socket) + * to '\n' delimited strings and returns an iterator to access the strings. + */ + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + new NextIterator[String] { + protected override def getNext() = { + val nextValue = dataInputStream.readLine() + if (nextValue == null) { + finished = true + } + nextValue + } + + protected override def close() { + dataInputStream.close() + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala new file mode 100644 index 0000000000..c1c9f808f0 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
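+ *
+ * An illustrative usage sketch (not part of this patch): it assumes the corresponding
+ * StreamingContext.socketTextStream helper, which passes this function to
+ * SocketInputDStream as its bytesToObjects argument; host and port are placeholders.
+ * {{{
+ * // each '\n'-terminated line received on localhost:9999 becomes one String element
+ * val lines = ssc.socketTextStream("localhost", 9999, StorageLevel.MEMORY_ONLY_SER)
+ * }}}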
+ */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.Partitioner +import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, Time, DStream} + +private[streaming] +class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest]( + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], + partitioner: Partitioner, + preservePartitioning: Boolean + ) extends DStream[(K, S)](parent.ssc) { + + super.persist(StorageLevel.MEMORY_ONLY_SER) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[(K, S)]] = { + + // Try to get the previous state RDD + getOrCompute(validTime - slideDuration) match { + + case Some(prevStateRDD) => { // If previous state RDD exists + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption) + }) + updateFuncLocal(i) + } + val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) + val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime) + return Some(stateRDD) + } + case None => { // If parent RDD does not exist + + // Re-apply the update function to the old state RDD + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, S)]) => { + val i = iterator.map(t => (t._1, Seq[V](), Option(t._2))) + updateFuncLocal(i) + } + val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) + return Some(stateRDD) + } + } + } + + case None => { // If previous session RDD does not exist (first input data) + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) + } + + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime + " (first)") + return Some(sessionRDD) + } + case None => { // If parent RDD does not exist, then nothing to do! 
+ //logDebug("Not generating state RDD (no previous state, no parent)") + return None + } + } + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala new file mode 100644 index 0000000000..edba2032b4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.streaming.{Duration, DStream, Time} + +private[streaming] +class TransformedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + transformFunc: (RDD[T], Time) => RDD[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TwitterInputDStream.scala new file mode 100644 index 0000000000..387e15b0e6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TwitterInputDStream.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark._ +import org.apache.spark.streaming._ +import storage.StorageLevel +import twitter4j._ +import twitter4j.auth.Authorization +import java.util.prefs.Preferences +import twitter4j.conf.ConfigurationBuilder +import twitter4j.conf.PropertyConfiguration +import twitter4j.auth.OAuthAuthorization +import twitter4j.auth.AccessToken + +/* A stream of Twitter statuses, potentially filtered by one or more keywords. 
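+*
+* An illustrative usage sketch (not part of this patch): it assumes the corresponding
+* StreamingContext.twitterStream helper and uses placeholder filter terms; passing None
+* for the Authorization falls back to the twitter4j.oauth.* system properties described below.
+* {{{
+* val tweets = ssc.twitterStream(None, Seq("spark", "streaming"), StorageLevel.MEMORY_ONLY_SER)
+* }}}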
+* +* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. +* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is +* such that this may return a sampled subset of all tweets during each interval. +* +* If no Authorization object is provided, initializes OAuth authorization using the system +* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. +*/ +private[streaming] +class TwitterInputDStream( + @transient ssc_ : StreamingContext, + twitterAuth: Option[Authorization], + filters: Seq[String], + storageLevel: StorageLevel + ) extends NetworkInputDStream[Status](ssc_) { + + private def createOAuthAuthorization(): Authorization = { + new OAuthAuthorization(new ConfigurationBuilder().build()) + } + + private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) + + override def getReceiver(): NetworkReceiver[Status] = { + new TwitterReceiver(authorization, filters, storageLevel) + } +} + +private[streaming] +class TwitterReceiver( + twitterAuth: Authorization, + filters: Seq[String], + storageLevel: StorageLevel + ) extends NetworkReceiver[Status] { + + var twitterStream: TwitterStream = _ + lazy val blockGenerator = new BlockGenerator(storageLevel) + + protected override def onStart() { + blockGenerator.start() + twitterStream = new TwitterStreamFactory().getInstance(twitterAuth) + twitterStream.addListener(new StatusListener { + def onStatus(status: Status) = { + blockGenerator += status + } + // Unimplemented + def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} + def onTrackLimitationNotice(i: Int) {} + def onScrubGeo(l: Long, l1: Long) {} + def onStallWarning(stallWarning: StallWarning) {} + def onException(e: Exception) { stopOnError(e) } + }) + + val query: FilterQuery = new FilterQuery + if (filters.size > 0) { + query.track(filters.toArray) + twitterStream.filter(query) + } else { + twitterStream.sample() + } + logInfo("Twitter receiver started") + } + + protected override def onStop() { + blockGenerator.stop() + twitterStream.shutdown() + logInfo("Twitter receiver stopped") + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala new file mode 100644 index 0000000000..97eab97b2f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.RDD +import collection.mutable.ArrayBuffer +import org.apache.spark.rdd.UnionRDD + +private[streaming] +class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) + extends DStream[T](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideDuration).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideDuration: Duration = parents.head.slideDuration + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala new file mode 100644 index 0000000000..dbbea39e81 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.RDD +import org.apache.spark.rdd.UnionRDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, Interval, Time, DStream} + +private[streaming] +class WindowedDStream[T: ClassManifest]( + parent: DStream[T], + _windowDuration: Duration, + _slideDuration: Duration) + extends DStream[T](parent.ssc) { + + if (!_windowDuration.isMultipleOf(parent.slideDuration)) + throw new Exception("The window duration of WindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") + + if (!_slideDuration.isMultipleOf(parent.slideDuration)) + throw new Exception("The slide duration of WindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") + + parent.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowDuration: Duration = _windowDuration + + override def dependencies = List(parent) + + override def slideDuration: Duration = _slideDuration + + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def compute(validTime: Time): Option[RDD[T]] = { + val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime) + Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) + } +} + + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala new file mode 100644 index 0000000000..4b5d8c467e --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receivers + +import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy } +import akka.actor.{ actorRef2Scala, ActorRef } +import akka.actor.{ PossiblyHarmful, OneForOneStrategy } + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.NetworkReceiver + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer + +/** A helper with set of defaults for supervisor strategy **/ +object ReceiverSupervisorStrategy { + + import akka.util.duration._ + import akka.actor.SupervisorStrategy._ + + val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = + 15 millis) { + case _: RuntimeException ⇒ Restart + case _: Exception ⇒ Escalate + } +} + +/** + * A receiver trait to be mixed in with your Actor to gain access to + * pushBlock API. 
+ * + * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html + * + * @example {{{ + * class MyActor extends Actor with Receiver { + * def receive = { + * case anything: String ⇒ pushBlock(anything) + * } + * } + * // Can be plugged into actorStream as follows + * ssc.actorStream[String](Props(new MyActor), "MyActorReceiver") + * + * }}} + * + * @note Since the Actor may exist outside the Spark framework, it is the user's + * responsibility to ensure type safety, i.e. the parameterized type of pushBlock + * and of the InputDStream should be the same. + * + */ +trait Receiver { self: Actor ⇒ + def pushBlock[T: ClassManifest](iter: Iterator[T]) { + context.parent ! Data(iter) + } + + def pushBlock[T: ClassManifest](data: T) { + context.parent ! Data(data) + } + +} + +/** + * Statistics for querying the supervisor about the state of workers. + */ +case class Statistics(numberOfMsgs: Int, + numberOfWorkers: Int, + numberOfHiccups: Int, + otherInfo: String) + +/** Case class to receive data sent by child actors **/ +private[streaming] case class Data[T: ClassManifest](data: T) + +/** + * Provides Actors as receivers for a Spark Streaming input stream. + * + * Actors can be used to receive data from almost any stream source. A nice set of + * abstractions for actors as receivers is already provided for a few general cases, + * and it is exposed as an API so that users can plug in their own Actor to run as a + * receiver for a Spark Streaming input source. + * + * This starts a supervisor actor which starts workers and also provides + * [http://doc.akka.io/docs/akka/2.0.5/scala/fault-tolerance.html fault-tolerance]. + * + * Here is how a worker can ask its parent supervisor to start more supervisors/workers + * as children: + * + * @example {{{ + * context.parent ! Props(new Supervisor) + * }}} OR {{{ + * context.parent ! Props(new Worker, "Worker") + * }}} + * + * + */ +private[streaming] class ActorReceiver[T: ClassManifest]( + props: Props, + name: String, + storageLevel: StorageLevel, + receiverSupervisorStrategy: SupervisorStrategy) + extends NetworkReceiver[T] { + + protected lazy val blocksGenerator: BlockGenerator = + new BlockGenerator(storageLevel) + + protected lazy val supervisor = env.actorSystem.actorOf(Props(new Supervisor), + "Supervisor" + streamId) + + private class Supervisor extends Actor { + + override val supervisorStrategy = receiverSupervisorStrategy + val worker = context.actorOf(props, name) + logInfo("Started receiver worker at:" + worker.path) + + val n: AtomicInteger = new AtomicInteger(0) + val hiccups: AtomicInteger = new AtomicInteger(0) + + def receive = { + + case Data(iter: Iterator[_]) ⇒ pushBlock(iter.asInstanceOf[Iterator[T]]) + + case Data(msg) ⇒ + blocksGenerator += msg.asInstanceOf[T] + n.incrementAndGet + + case props: Props ⇒ + val worker = context.actorOf(props) + logInfo("Started receiver worker at:" + worker.path) + sender ! worker + + case (props: Props, name: String) ⇒ + val worker = context.actorOf(props, name) + logInfo("Started receiver worker at:" + worker.path) + sender ! worker + + case _: PossiblyHarmful => hiccups.incrementAndGet() + + case _: Statistics ⇒ + val workers = context.children + sender !
Statistics(n.get, workers.size, hiccups.get, workers.mkString("\n")) + + } + } + + protected def pushBlock(iter: Iterator[T]) { + val buffer = new ArrayBuffer[T] + buffer ++= iter + pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel) + } + + protected def onStart() = { + blocksGenerator.start() + supervisor + logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + } + + protected def onStop() = { + supervisor ! PoisonPill + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala new file mode 100644 index 0000000000..043bb8c8bf --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receivers + +import akka.actor.Actor +import akka.zeromq._ + +import org.apache.spark.Logging + +/** + * A receiver to subscribe to ZeroMQ stream. + */ +private[streaming] class ZeroMQReceiver[T: ClassManifest](publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T]) + extends Actor with Receiver with Logging { + + override def preStart() = context.system.newSocket(SocketType.Sub, Listener(self), + Connect(publisherUrl), subscribe) + + def receive: Receive = { + + case Connecting ⇒ logInfo("connecting ...") + + case m: ZMQMessage ⇒ + logDebug("Received message for:" + m.firstFrameAsString) + + //We ignore first frame for processing as it is the topic + val bytes = m.frames.tail.map(_.payload) + pushBlock(bytesToObjects(bytes)) + + case Closed ⇒ logInfo("received closed ") + + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala new file mode 100644 index 0000000000..f67bb2f6ac --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
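Putting the Receiver trait and ActorReceiver together, here is a hedged sketch of wiring a custom actor into a job via actorStream, along the lines of the @example in the scaladoc above; EchoActor, the stream name and the downstream operations are illustrative.

import akka.actor.{Actor, Props}
import org.apache.spark.streaming.receivers.Receiver

// Pushes every String it receives into the stream. As the note above says, the
// type given to pushBlock must match the element type of the resulting DStream.
class EchoActor extends Actor with Receiver {
  def receive = {
    case line: String => pushBlock(line)
  }
}

// Assuming `ssc` is an existing StreamingContext:
// val lines = ssc.actorStream[String](Props(new EchoActor), "EchoReceiver")
// lines.flatMap(_.split(" ")).countByValue().print()

With the default ReceiverSupervisorStrategy shown earlier, the supervisor restarts the worker actor on a RuntimeException and escalates any other exception.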
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +private[streaming] +trait Clock { + def currentTime(): Long + def waitTillTime(targetTime: Long): Long +} + +private[streaming] +class SystemClock() extends Clock { + + val minPollTime = 25L + + def currentTime(): Long = { + System.currentTimeMillis() + } + + def waitTillTime(targetTime: Long): Long = { + var currentTime = 0L + currentTime = System.currentTimeMillis() + + var waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + + val pollTime = { + if (waitTime / 10.0 > minPollTime) { + (waitTime / 10.0).toLong + } else { + minPollTime + } + } + + + while (true) { + currentTime = System.currentTimeMillis() + waitTime = targetTime - currentTime + + if (waitTime <= 0) { + + return currentTime + } + val sleepTime = + if (waitTime < pollTime) { + waitTime + } else { + pollTime + } + Thread.sleep(sleepTime) + } + return -1 + } +} + +private[streaming] +class ManualClock() extends Clock { + + var time = 0L + + def currentTime() = time + + def setTime(timeToSet: Long) = { + this.synchronized { + time = timeToSet + this.notifyAll() + } + } + + def addToTime(timeToAdd: Long) = { + this.synchronized { + time += timeToAdd + this.notifyAll() + } + } + def waitTillTime(targetTime: Long): Long = { + this.synchronized { + while (time < targetTime) { + this.wait(100) + } + } + return currentTime() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala new file mode 100644 index 0000000000..50d72298e4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
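SystemClock sleeps in small polls toward the target time, while ManualClock only advances when told to, which is what makes batch scheduling deterministic in the test suites (they select it through the spark.streaming.clock property, which MasterFailureTest below clears again). A rough sketch of ManualClock's behaviour; since both classes are private[streaming], this only compiles inside that package:

val clock = new ManualClock()                  // manual time starts at 0

// A worker blocks until manual time reaches 2000.
new Thread(new Runnable {
  def run() { println("reached " + clock.waitTillTime(2000)) }
}).start()

clock.addToTime(1000)                          // time = 1000, the worker keeps waiting
clock.addToTime(1000)                          // time = 2000, the worker wakes and prints "reached 2000"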
+ */ + +package org.apache.spark.streaming.util + +import org.apache.spark.{Logging, RDD} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.ForEachDStream +import StreamingContext._ + +import scala.util.Random +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} + +import java.io.{File, ObjectInputStream, IOException} +import java.util.UUID + +import com.google.common.io.Files + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.{FileUtil, FileSystem, Path} +import org.apache.hadoop.conf.Configuration + + +private[streaming] +object MasterFailureTest extends Logging { + initLogging() + + @volatile var killed = false + @volatile var killCount = 0 + + def main(args: Array[String]) { + if (args.size < 2) { + println( + "Usage: MasterFailureTest <# batches> []") + System.exit(1) + } + val directory = args(0) + val numBatches = args(1).toInt + val batchDuration = if (args.size > 2) Milliseconds(args(2).toInt) else Seconds(1) + + println("\n\n========================= MAP TEST =========================\n\n") + testMap(directory, numBatches, batchDuration) + + println("\n\n================= UPDATE-STATE-BY-KEY TEST =================\n\n") + testUpdateStateByKey(directory, numBatches, batchDuration) + + println("\n\nSUCCESS\n\n") + } + + def testMap(directory: String, numBatches: Int, batchDuration: Duration) { + // Input: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... + val input = (1 to numBatches).map(_.toString).toSeq + // Expected output: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... + val expectedOutput = (1 to numBatches) + + val operation = (st: DStream[String]) => st.map(_.toInt) + + // Run streaming operation with multiple master failures + val output = testOperation(directory, batchDuration, input, operation, expectedOutput) + + logInfo("Expected output, size = " + expectedOutput.size) + logInfo(expectedOutput.mkString("[", ",", "]")) + logInfo("Output, size = " + output.size) + logInfo(output.mkString("[", ",", "]")) + + // Verify whether all the values of the expected output is present + // in the output + assert(output.distinct.toSet == expectedOutput.toSet) + } + + + def testUpdateStateByKey(directory: String, numBatches: Int, batchDuration: Duration) { + // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... + val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq + // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... 
+ val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j)) + + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Long], state: Option[Long]) => { + Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) + } + st.flatMap(_.split(" ")) + .map(x => (x, 1L)) + .updateStateByKey[Long](updateFunc) + .checkpoint(batchDuration * 5) + } + + // Run streaming operation with multiple master failures + val output = testOperation(directory, batchDuration, input, operation, expectedOutput) + + logInfo("Expected output, size = " + expectedOutput.size + "\n" + expectedOutput) + logInfo("Output, size = " + output.size + "\n" + output) + + // Verify whether all the values in the output are among the expected output values + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + + // Verify whether the last expected output value has been generated, there by + // confirming that none of the inputs have been missed + assert(output.last == expectedOutput.last) + } + + /** + * Tests stream operation with multiple master failures, and verifies whether the + * final set of output values is as expected or not. + */ + def testOperation[T: ClassManifest]( + directory: String, + batchDuration: Duration, + input: Seq[String], + operation: DStream[String] => DStream[T], + expectedOutput: Seq[T] + ): Seq[T] = { + + // Just making sure that the expected output does not have duplicates + assert(expectedOutput.distinct.toSet == expectedOutput.toSet) + + // Setup the stream computation with the given operation + val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) + + // Start generating files in the a different thread + val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) + fileGeneratingThread.start() + + // Run the streams and repeatedly kill it until the last expected output + // has been generated, or until it has run for twice the expected time + val lastExpectedOutput = expectedOutput.last + val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 + val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) + + // Delete directories + fileGeneratingThread.join() + val fs = checkpointDir.getFileSystem(new Configuration()) + fs.delete(checkpointDir, true) + fs.delete(testDir, true) + logInfo("Finished test after " + killCount + " failures") + mergedOutput + } + + /** + * Sets up the stream computation with the given operation, directory (local or HDFS), + * and batch duration. Returns the streaming context and the directory to which + * files should be written for testing. 
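From main() above, the test takes a local or HDFS directory, the number of batches to generate, and an optional batch duration in milliseconds. A hypothetical programmatic invocation; the path and numbers are placeholders:

MasterFailureTest.main(Array(
  "/tmp/master-failure-test",   // directory used for checkpoints and generated input files
  "30",                         // number of input batches to generate
  "1000"                        // optional batch duration in milliseconds (defaults to 1 second)
))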
+ */ + private def setupStreams[T: ClassManifest]( + directory: String, + batchDuration: Duration, + operation: DStream[String] => DStream[T] + ): (StreamingContext, Path, Path) = { + // Reset all state + reset() + + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(directory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val checkpointDir = new Path(rootDir, "checkpoint") + val testDir = new Path(rootDir, "test") + fs.mkdirs(checkpointDir) + fs.mkdirs(testDir) + + // Setup the streaming computation with the given operation + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) + ssc.checkpoint(checkpointDir.toString) + val inputStream = ssc.textFileStream(testDir.toString) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStream(operatedStream) + ssc.registerOutputStream(outputStream) + (ssc, checkpointDir, testDir) + } + + + /** + * Repeatedly starts and kills the streaming context until timed out or + * the last expected output is generated. Finally, return + */ + private def runStreams[T: ClassManifest]( + ssc_ : StreamingContext, + lastExpectedOutput: T, + maxTimeToRun: Long + ): Seq[T] = { + + var ssc = ssc_ + var totalTimeRan = 0L + var isLastOutputGenerated = false + var isTimedOut = false + val mergedOutput = new ArrayBuffer[T]() + val checkpointDir = ssc.checkpointDir + var batchDuration = ssc.graph.batchDuration + + while(!isLastOutputGenerated && !isTimedOut) { + // Get the output buffer + val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output + def output = outputBuffer.flatMap(x => x) + + // Start the thread to kill the streaming after some time + killed = false + val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10) + killingThread.start() + + var timeRan = 0L + try { + // Start the streaming computation and let it run while ... 
+ // (i) StreamingContext has not been shut down yet + // (ii) The last expected output has not been generated yet + // (iii) Its not timed out yet + System.clearProperty("spark.streaming.clock") + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + ssc.start() + val startTime = System.currentTimeMillis() + while (!killed && !isLastOutputGenerated && !isTimedOut) { + Thread.sleep(100) + timeRan = System.currentTimeMillis() - startTime + isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput) + isTimedOut = (timeRan + totalTimeRan > maxTimeToRun) + } + } catch { + case e: Exception => logError("Error running streaming context", e) + } + if (killingThread.isAlive) killingThread.interrupt() + ssc.stop() + + logInfo("Has been killed = " + killed) + logInfo("Is last output generated = " + isLastOutputGenerated) + logInfo("Is timed out = " + isTimedOut) + + // Verify whether the output of each batch has only one element or no element + // and then merge the new output with all the earlier output + mergedOutput ++= output + totalTimeRan += timeRan + logInfo("New output = " + output) + logInfo("Merged output = " + mergedOutput) + logInfo("Time ran = " + timeRan) + logInfo("Total time ran = " + totalTimeRan) + + if (!isLastOutputGenerated && !isTimedOut) { + val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 10) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + // Recreate the streaming context from checkpoint + ssc = new StreamingContext(checkpointDir) + } + } + mergedOutput + } + + /** + * Verifies the output value are the same as expected. Since failures can lead to + * a batch being processed twice, a batches output may appear more than once + * consecutively. To avoid getting confused with those, we eliminate consecutive + * duplicate batch outputs of values from the `output`. As a result, the + * expected output should not have consecutive batches with the same values as output. + */ + private def verifyOutput[T: ClassManifest](output: Seq[T], expectedOutput: Seq[T]) { + // Verify whether expected outputs do not consecutive batches with same output + for (i <- 0 until expectedOutput.size - 1) { + assert(expectedOutput(i) != expectedOutput(i+1), + "Expected output has consecutive duplicate sequence of values") + } + + // Log the output + println("Expected output, size = " + expectedOutput.size) + println(expectedOutput.mkString("[", ",", "]")) + println("Output, size = " + output.size) + println(output.mkString("[", ",", "]")) + + // Match the output with the expected output + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + } + + /** Resets counter to prepare for the test */ + private def reset() { + killed = false + killCount = 0 + } +} + +/** + * This is a output stream just for testing. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. 
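The output inspected by runStreams() comes from the TestOutputStream defined just below, which appends each batch's collected RDD to a synchronized buffer. A condensed sketch of that pattern; someDStream and its element type are assumed, and since the class is private[streaming] this belongs in test code inside that package:

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

val outputBuffer = new ArrayBuffer[Seq[Int]] with SynchronizedBuffer[Seq[Int]]
val outputStream = new TestOutputStream(someDStream.map(_.toInt), outputBuffer)
ssc.registerOutputStream(outputStream)
ssc.start()
// ... once the batches have run, the buffer holds one Seq per batch:
val allValues = outputBuffer.flatMap(batch => batch)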
+ */ +private[streaming] +class TestOutputStream[T: ClassManifest]( + parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + ) extends ForEachDStream[T]( + parent, + (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + } + ) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} + + +/** + * Thread to kill streaming context after a random period of time. + */ +private[streaming] +class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { + initLogging() + + override def run() { + try { + // If it is the first killing, then allow the first checkpoint to be created + var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 2000 + val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + if (ssc != null) { + ssc.stop() + MasterFailureTest.killed = true + MasterFailureTest.killCount += 1 + } + logInfo("Killing thread finished normally") + } catch { + case ie: InterruptedException => logInfo("Killing thread interrupted") + case e: Exception => logWarning("Exception in killing thread", e) + } + + } +} + + +/** + * Thread to generate input files periodically with the desired text. + */ +private[streaming] +class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) + extends Thread with Logging { + initLogging() + + override def run() { + val localTestDir = Files.createTempDir() + var fs = testDir.getFileSystem(new Configuration()) + val maxTries = 3 + try { + Thread.sleep(5000) // To make sure that all the streaming context has been set up + for (i <- 0 until input.size) { + // Write the data to a local file and then move it to the target test directory + val localFile = new File(localTestDir, (i+1).toString) + val hadoopFile = new Path(testDir, (i+1).toString) + val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString) + FileUtils.writeStringToFile(localFile, input(i).toString + "\n") + var tries = 0 + var done = false + while (!done && tries < maxTries) { + tries += 1 + try { + // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile) + fs.rename(tempHadoopFile, hadoopFile) + done = true + } catch { + case ioe: IOException => { + fs = testDir.getFileSystem(new Configuration()) + logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) + } + } + } + if (!done) + logError("Could not generate file " + hadoopFile) + else + logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis) + Thread.sleep(interval) + localFile.delete() + } + logInfo("File generating thread finished normally") + } catch { + case ie: InterruptedException => logInfo("File generating thread interrupted") + case e: Exception => logWarning("File generating in killing thread", e) + } finally { + fs.close() + } + } +} + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala new file mode 100644 index 
0000000000..4e6ce6eabd --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import scala.collection.JavaConversions.mapAsScalaMap + +object RawTextHelper { + + /** + * Splits lines and counts the words in them using specialized object-to-long hashmap + * (to avoid boxing-unboxing overhead of Long in java/scala HashMap) + */ + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { + val map = new OLMap[String] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.getLong(w) + map.put(w, c + 1) + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator.map{case (k, v) => (k, v)} + } + + /** + * Gets the top k words in terms of word counts. Assumes that each word exists only once + * in the `data` iterator (that is, the counts have been reduced). + */ + def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { + val taken = new Array[(String, Long)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, Long) = null + var swap: (String, Long) = null + var count = 0 + + while(data.hasNext) { + value = data.next + if (value != null) { + count += 1 + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + } + return taken.toIterator + } + + /** + * Warms up the SparkContext in master and slave by running tasks to force JIT kick in + * before real workload starts. 
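A quick, self-contained illustration of the two helpers above, run locally with no SparkContext; the input strings are arbitrary:

val counts = RawTextHelper.splitAndCountPartitions(Iterator("a a b", "b c")).toSeq
// counts contains ("a", 2), ("b", 2) and ("c", 1), in hash-map order
val top2 = RawTextHelper.topK(counts.iterator, 2).toList
// top2 holds the two highest-count pairs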
+ */ + def warmUp(sc: SparkContext) { + for(i <- 0 to 1) { + sc.parallelize(1 to 200000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .count() + } + } + + def add(v1: Long, v2: Long) = (v1 + v2) + + def subtract(v1: Long, v2: Long) = (v1 - v2) + + def max(v1: Long, v2: Long) = math.max(v1, v2) +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala new file mode 100644 index 0000000000..249f6a22ae --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import org.apache.spark.util.{RateLimitedOutputStream, IntParam} +import java.net.ServerSocket +import org.apache.spark.{Logging, KryoSerializer} +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import scala.io.Source +import java.io.IOException + +/** + * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a + * specified rate. Used to feed data into RawInputDStream. 
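The pattern match in main() below shows the four expected arguments: a port to listen on, a text file to replay, the block size in bytes, and the target rate in bytes per second. A hypothetical invocation; the file path and numbers are placeholders:

// Listen on port 9999, replay /tmp/words.txt in ~64 KB Kryo-serialized blocks,
// throttled to roughly 1 MB/s for each connected client.
RawTextSender.main(Array("9999", "/tmp/words.txt", "65536", "1000000"))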
+ */ +object RawTextSender extends Logging { + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: RawTextSender ") + System.exit(1) + } + // Parse the arguments using a pattern match + val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args + + // Repeat the input data multiple times to fill in a buffer + val lines = Source.fromFile(file).getLines().toArray + val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) + val ser = new KryoSerializer().newInstance() + val serStream = ser.serializeStream(bufferStream) + var i = 0 + while (bufferStream.position < blockSize) { + serStream.writeObject(lines(i)) + i = (i + 1) % lines.length + } + bufferStream.trim() + val array = bufferStream.array + + val countBuf = ByteBuffer.wrap(new Array[Byte](4)) + countBuf.putInt(array.length) + countBuf.flip() + + val serverSocket = new ServerSocket(port) + logInfo("Listening on port " + port) + + while (true) { + val socket = serverSocket.accept() + logInfo("Got a new connection") + val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) + try { + while (true) { + out.write(countBuf.array) + out.write(array) + } + } catch { + case e: IOException => + logError("Client disconnected") + socket.close() + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala new file mode 100644 index 0000000000..d644240405 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.util + +private[streaming] +class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { + + private val minPollTime = 25L + + private val pollTime = { + if (period / 10.0 > minPollTime) { + (period / 10.0).toLong + } else { + minPollTime + } + } + + private val thread = new Thread() { + override def run() { loop } + } + + private var nextTime = 0L + + def getStartTime(): Long = { + (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + } + + def getRestartTime(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime + } + + def start(startTime: Long): Long = { + nextTime = startTime + thread.start() + nextTime + } + + def start(): Long = { + start(getStartTime()) + } + + def stop() { + thread.interrupt() + } + + private def loop() { + try { + while (true) { + clock.waitTillTime(nextTime) + callback(nextTime) + nextTime += period + } + + } catch { + case e: InterruptedException => + } + } +} + +private[streaming] +object RecurringTimer { + + def main(args: Array[String]) { + var lastRecurTime = 0L + val period = 1000 + + def onRecur(time: Long) { + val currentTime = System.currentTimeMillis() + println("" + currentTime + ": " + (currentTime - lastRecurTime)) + lastRecurTime = currentTime + } + val timer = new RecurringTimer(new SystemClock(), period, onRecur) + timer.start() + Thread.sleep(30 * 1000) + timer.stop() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala deleted file mode 100644 index 070d930b5e..0000000000 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
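Because RecurringTimer takes a Clock, tests can drive it with the ManualClock shown earlier rather than waiting on wall-clock time. A sketch of that combination; both classes are private[streaming], so this only compiles inside the streaming package:

val clock = new ManualClock()
val timer = new RecurringTimer(clock, 1000, (t: Long) => println("callback for " + t))
timer.start(1000)        // first callback scheduled for manual time 1000
clock.addToTime(3000)    // releases the callbacks for t = 1000, 2000 and 3000
// ... later:
timer.stop()             // interrupts the timer thread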
- */ - -package spark.streaming - -import java.io._ -import java.util.concurrent.Executors -import java.util.concurrent.RejectedExecutionException - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - -import spark.Logging -import spark.io.CompressionCodec - - -private[streaming] -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) - extends Logging with Serializable { - val master = ssc.sc.master - val framework = ssc.sc.appName - val sparkHome = ssc.sc.sparkHome - val jars = ssc.sc.jars - val environment = ssc.sc.environment - val graph = ssc.graph - val checkpointDir = ssc.checkpointDir - val checkpointDuration = ssc.checkpointDuration - val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() - - def validate() { - assert(master != null, "Checkpoint.master is null") - assert(framework != null, "Checkpoint.framework is null") - assert(graph != null, "Checkpoint.graph is null") - assert(checkpointTime != null, "Checkpoint.checkpointTime is null") - logInfo("Checkpoint for time " + checkpointTime + " validated") - } -} - - -/** - * Convenience class to speed up the writing of graph checkpoint to file - */ -private[streaming] -class CheckpointWriter(checkpointDir: String) extends Logging { - val file = new Path(checkpointDir, "graph") - // The file to which we actually write - and then "move" to file. - private val writeFile = new Path(file.getParent, file.getName + ".next") - private val bakFile = new Path(file.getParent, file.getName + ".bk") - - private var stopped = false - - val conf = new Configuration() - var fs = file.getFileSystem(conf) - val maxAttempts = 3 - val executor = Executors.newFixedThreadPool(1) - - private val compressionCodec = CompressionCodec.createCodec() - - // Removed code which validates whether there is only one CheckpointWriter per path 'file' since - // I did not notice any errors - reintroduce it ? - - class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { - def run() { - var attempts = 0 - val startTime = System.currentTimeMillis() - while (attempts < maxAttempts) { - attempts += 1 - try { - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - // This is inherently thread unsafe .. so alleviating it by writing to '.new' and then doing moves : which should be pretty fast. 
- val fos = fs.create(writeFile) - fos.write(bytes) - fos.close() - if (fs.exists(file) && fs.rename(file, bakFile)) { - logDebug("Moved existing checkpoint file to " + bakFile) - } - // paranoia - fs.delete(file, false) - fs.rename(writeFile, file) - - val finishTime = System.currentTimeMillis(); - logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + - "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") - return - } catch { - case ioe: IOException => - logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) - } - } - logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") - } - } - - def write(checkpoint: Checkpoint) { - val bos = new ByteArrayOutputStream() - val zos = compressionCodec.compressedOutputStream(bos) - val oos = new ObjectOutputStream(zos) - oos.writeObject(checkpoint) - oos.close() - bos.close() - try { - executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) - } catch { - case rej: RejectedExecutionException => - logError("Could not submit checkpoint task to the thread pool executor", rej) - } - } - - def stop() { - synchronized { - if (stopped) return ; - stopped = true - } - executor.shutdown() - val startTime = System.currentTimeMillis() - val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) - val endTime = System.currentTimeMillis() - logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.") - } -} - - -private[streaming] -object CheckpointReader extends Logging { - - def read(path: String): Checkpoint = { - val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) - - val compressionCodec = CompressionCodec.createCodec() - - attempts.foreach(file => { - if (fs.exists(file)) { - logInfo("Attempting to load checkpoint from file '" + file + "'") - try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. 
This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - logInfo("Checkpoint successfully loaded from file '" + file + "'") - logInfo("Checkpoint was generated at time " + cp.checkpointTime) - return cp - } catch { - case e: Exception => - logError("Error loading checkpoint from file '" + file + "'", e) - } - } else { - logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") - } - - }) - throw new Exception("Could not read checkpoint from path '" + path + "'") - } -} - -private[streaming] -class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) - extends ObjectInputStream(inputStream_) { - - override def resolveClass(desc: ObjectStreamClass): Class[_] = { - try { - return loader.loadClass(desc.getName()) - } catch { - case e: Exception => - } - return super.resolveClass(desc) - } -} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala deleted file mode 100644 index 684d3abb56..0000000000 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ /dev/null @@ -1,700 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import spark.streaming.dstream._ -import StreamingContext._ -//import Time._ - -import spark.{RDD, Logging} -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} - -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.conf.Configuration - -/** - * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous - * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] - * for more details on RDDs). DStreams can either be created from live data (such as, data from - * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations - * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each - * DStream periodically generates a RDD, either from live data or by transforming the RDD generated - * by a parent DStream. - * - * This class contains the basic operations available on all DStreams, such as `map`, `filter` and - * `window`. 
In addition, [[spark.streaming.PairDStreamFunctions]] contains operations available - * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations - * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through - * implicit conversions when `spark.streaming.StreamingContext._` is imported. - * - * DStreams internally is characterized by a few basic properties: - * - A list of other DStreams that the DStream depends on - * - A time interval at which the DStream generates an RDD - * - A function that is used to generate an RDD after each time interval - */ - -abstract class DStream[T: ClassManifest] ( - @transient protected[streaming] var ssc: StreamingContext - ) extends Serializable with Logging { - - initLogging() - - // ======================================================================= - // Methods that should be implemented by subclasses of DStream - // ======================================================================= - - /** Time interval after which the DStream generates a RDD */ - def slideDuration: Duration - - /** List of parent DStreams on which this DStream depends on */ - def dependencies: List[DStream[_]] - - /** Method that generates a RDD for the given time */ - def compute (validTime: Time): Option[RDD[T]] - - // ======================================================================= - // Methods and fields available on all DStreams - // ======================================================================= - - // RDDs generated, marked as protected[streaming] so that testsuites can access it - @transient - protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () - - // Time zero for the DStream - protected[streaming] var zeroTime: Time = null - - // Duration for which the DStream will remember each RDD created - protected[streaming] var rememberDuration: Duration = null - - // Storage level of the RDDs in the stream - protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE - - // Checkpoint details - protected[streaming] val mustCheckpoint = false - protected[streaming] var checkpointDuration: Duration = null - protected[streaming] val checkpointData = new DStreamCheckpointData(this) - - // Reference to whole DStream graph - protected[streaming] var graph: DStreamGraph = null - - protected[streaming] def isInitialized = (zeroTime != null) - - // Duration for which the DStream requires its parent DStream to remember each RDD created - protected[streaming] def parentRememberDuration = rememberDuration - - /** Return the StreamingContext associated with this DStream */ - def context = ssc - - /** Persist the RDDs of this DStream with the given storage level */ - def persist(level: StorageLevel): DStream[T] = { - if (this.isInitialized) { - throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after streaming context has started") - } - this.storageLevel = level - this - } - - /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) - - /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def cache(): DStream[T] = persist() - - /** - * Enable periodic checkpointing of RDDs of this DStream - * @param interval Time interval after which generated RDD will be checkpointed - */ - def checkpoint(interval: Duration): DStream[T] = { - if (isInitialized) { - throw new UnsupportedOperationException( - "Cannot change 
checkpoint interval of an DStream after streaming context has started") - } - persist() - checkpointDuration = interval - this - } - - /** - * Initialize the DStream by setting the "zero" time, based on which - * the validity of future times is calculated. This method also recursively initializes - * its parent DStreams. - */ - protected[streaming] def initialize(time: Time) { - if (zeroTime != null && zeroTime != time) { - throw new Exception("ZeroTime is already initialized to " + zeroTime - + ", cannot initialize it again to " + time) - } - zeroTime = time - - // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger - if (mustCheckpoint && checkpointDuration == null) { - checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt - logInfo("Checkpoint interval automatically set to " + checkpointDuration) - } - - // Set the minimum value of the rememberDuration if not already set - var minRememberDuration = slideDuration - if (checkpointDuration != null && minRememberDuration <= checkpointDuration) { - minRememberDuration = checkpointDuration * 2 // times 2 just to be sure that the latest checkpoint is not forgetten - } - if (rememberDuration == null || rememberDuration < minRememberDuration) { - rememberDuration = minRememberDuration - } - - // Initialize the dependencies - dependencies.foreach(_.initialize(zeroTime)) - } - - protected[streaming] def validate() { - assert(rememberDuration != null, "Remember duration is set to null") - - assert( - !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + - " Please use DStream.checkpoint() to set the interval." - ) - - assert( - checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, - "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + - " or SparkContext.checkpoint() to set the checkpoint directory." - ) - - assert( - checkpointDuration == null || checkpointDuration >= slideDuration, - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + - "Please set it to at least " + slideDuration + "." - ) - - assert( - checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple " + slideDuration + "." - ) - - assert( - checkpointDuration == null || storageLevel != StorageLevel.NONE, - "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + - "level has not been set to enable persisting. Please use DStream.persist() to set the " + - "storage level to use memory for better checkpointing performance." - ) - - assert( - checkpointDuration == null || rememberDuration > checkpointDuration, - "The remember duration for " + this.getClass.getSimpleName + " has been set to " + - rememberDuration + " which is not more than the checkpoint interval (" + - checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." 
- ) - - val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds - logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - assert( - metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, - "It seems you are doing some DStream window operation or setting a checkpoint interval " + - "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + - "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" + - "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " + - "set the Java property 'spark.cleaner.delay' to more than " + - math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." - ) - - dependencies.foreach(_.validate()) - - logInfo("Slide time = " + slideDuration) - logInfo("Storage level = " + storageLevel) - logInfo("Checkpoint interval = " + checkpointDuration) - logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized and validated " + this) - } - - protected[streaming] def setContext(s: StreamingContext) { - if (ssc != null && ssc != s) { - throw new Exception("Context is already set in " + this + ", cannot set it again") - } - ssc = s - logInfo("Set context for " + this) - dependencies.foreach(_.setContext(ssc)) - } - - protected[streaming] def setGraph(g: DStreamGraph) { - if (graph != null && graph != g) { - throw new Exception("Graph is already set in " + this + ", cannot set it again") - } - graph = g - dependencies.foreach(_.setGraph(graph)) - } - - protected[streaming] def remember(duration: Duration) { - if (duration != null && duration > rememberDuration) { - rememberDuration = duration - logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) - } - dependencies.foreach(_.remember(parentRememberDuration)) - } - - /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ - protected def isTimeValid(time: Time): Boolean = { - if (!isInitialized) { - throw new Exception (this + " has not been initialized") - } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { - logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) - false - } else { - logInfo("Time " + time + " is valid") - true - } - } - - /** - * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal - * method that should not be called directly. 
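Taken together, the validate() assertions above amount to: set a checkpoint directory on the context, and give streams that must checkpoint an interval that is at least, and a multiple of, their slide duration, and shorter than the remember duration. A hedged sketch of a configuration that satisfies them; the directory, durations and update function are illustrative, and `pairs` is assumed to be a DStream[(String, Int)]:

ssc.checkpoint("/tmp/streaming-checkpoints")     // required whenever a checkpoint interval is set

val updateFunc = (values: Seq[Int], state: Option[Int]) =>
  Some(values.sum + state.getOrElse(0))

val runningCounts = pairs
  .updateStateByKey[Int](updateFunc)
  .checkpoint(Seconds(10))                       // at least, and a multiple of, the slide duration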
- */ - protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { - // If this DStream was not initialized (i.e., zeroTime not set), then do it - // If RDD was already generated, then retrieve it from HashMap - generatedRDDs.get(time) match { - - // If an RDD was already generated and is being reused, then - // probably all RDDs in this DStream will be reused and hence should be cached - case Some(oldRDD) => Some(oldRDD) - - // if RDD was not generated, and if the time is valid - // (based on sliding time of this DStream), then generate the RDD - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => - if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time) - } - if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) { - newRDD.checkpoint() - logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - case None => - None - } - } else { - None - } - } - } - } - - /** - * Generate a SparkStreaming job for the given time. This is an internal method that - * should not be called directly. This default implementation creates a job - * that materializes the corresponding RDD. Subclasses of DStream may override this - * to generate their own jobs. - */ - protected[streaming] def generateJob(time: Time): Option[Job] = { - getOrCompute(time) match { - case Some(rdd) => { - val jobFunc = () => { - val emptyFunc = { (iterator: Iterator[T]) => {} } - context.sparkContext.runJob(rdd, emptyFunc) - } - Some(new Job(time, jobFunc)) - } - case None => None - } - } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This default - * implementation clears the old generated RDDs. Subclasses of DStream may override - * this to clear their own metadata along with the generated RDDs. - */ - protected[streaming] def clearOldMetadata(time: Time) { - var numForgotten = 0 - val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) - generatedRDDs --= oldRDDs.keys - logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) - dependencies.foreach(_.clearOldMetadata(time)) - } - - /* Adds metadata to the Stream while it is running. - * This methd should be overwritten by sublcasses of InputDStream. - */ - protected[streaming] def addMetadata(metadata: Any) { - if (metadata != null) { - logInfo("Dropping Metadata: " + metadata.toString) - } - } - - /** - * Refresh the list of checkpointed RDDs that will be saved along with checkpoint of - * this stream. This is an internal method that should not be called directly. This is - * a default implementation that saves only the file names of the checkpointed RDDs to - * checkpointData. Subclasses of DStream (especially those of InputDStream) may override - * this method to save custom checkpoint data. - */ - protected[streaming] def updateCheckpointData(currentTime: Time) { - logInfo("Updating checkpoint data for time " + currentTime) - checkpointData.update() - dependencies.foreach(_.updateCheckpointData(currentTime)) - checkpointData.cleanup() - logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) - } - - /** - * Restore the RDDs in generatedRDDs from the checkpointData. 
This is an internal method - * that should not be called directly. This is a default implementation that recreates RDDs - * from the checkpoint file names stored in checkpointData. Subclasses of DStream that - * override the updateCheckpointData() method would also need to override this method. - */ - protected[streaming] def restoreCheckpointData() { - // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data") - checkpointData.restore() - dependencies.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") - } - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - logDebug(this.getClass().getSimpleName + ".writeObject used") - if (graph != null) { - graph.synchronized { - if (graph.checkpointInProgress) { - oos.defaultWriteObject() - } else { - val msg = "Object of " + this.getClass.getName + " is being serialized " + - " possibly as a part of closure of an RDD operation. This is because " + - " the DStream object is being referred to from within the closure. " + - " Please rewrite the RDD operation inside this DStream to avoid this. " + - " This has been enforced to avoid bloating of Spark tasks " + - " with unnecessary objects." - throw new java.io.NotSerializableException(msg) - } - } - } else { - throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") - } - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - logDebug(this.getClass().getSimpleName + ".readObject used") - ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[T]] () - } - - // ======================================================================= - // DStream operations - // ======================================================================= - - /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { - new MappedDStream(this, context.sparkContext.clean(mapFunc)) - } - - /** - * Return a new DStream by applying a function to all elements of this DStream, - * and then flattening the results - */ - def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { - new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) - } - - /** Return a new DStream containing only the elements that satisfy a predicate. */ - def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) - - /** - * Return a new DStream in which each RDD is generated by applying glom() to each RDD of - * this DStream. Applying glom() to an RDD coalesces all elements within each partition into - * an array. - */ - def glom(): DStream[Array[T]] = new GlommedDStream(this) - - /** - * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs - * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition - * of the RDD. - */ - def mapPartitions[U: ClassManifest]( - mapPartFunc: Iterator[T] => Iterator[U], - preservePartitioning: Boolean = false - ): DStream[U] = { - new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) - } - - /** - * Return a new DStream in which each RDD has a single element generated by reducing each RDD - * of this DStream. 
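The per-element operators documented above compose exactly like their RDD counterparts. A small sketch over an assumed DStream[String] named lines:

val lengths = lines
  .filter(_.nonEmpty)                            // drop empty lines
  .map(_.length)                                 // one Int per line
val totalChars   = lengths.reduce(_ + _)         // single-element RDD per batch
val perPartition = lengths.mapPartitions(it => Iterator(it.sum))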
- */ - def reduce(reduceFunc: (T, T) => T): DStream[T] = - this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) - - /** - * Return a new DStream in which each RDD has a single element generated by counting each RDD - * of this DStream. - */ - def count(): DStream[Long] = { - this.map(_ => (null, 1L)) - .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) - .reduceByKey(_ + _) - .map(_._2) - } - - /** - * Return a new DStream in which each RDD contains the counts of each distinct value in - * each RDD of this DStream. Hash partitioning is used to generate - * the RDDs with `numPartitions` partitions (Spark's default number of partitions if - * `numPartitions` not specified). - */ - def countByValue(numPartitions: Int = ssc.sc.defaultParallelism): DStream[(T, Long)] = - this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * this DStream will be registered as an output stream and therefore materialized. - */ - def foreach(foreachFunc: RDD[T] => Unit) { - this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * this DStream will be registered as an output stream and therefore materialized. - */ - def foreach(foreachFunc: (RDD[T], Time) => Unit) { - val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) - ssc.registerOutputStream(newStream) - newStream - } - - /** - * Return a new DStream in which each RDD is generated by applying a function - * on each RDD of this DStream. - */ - def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - transform((r: RDD[T], t: Time) => transformFunc(r)) - } - - /** - * Return a new DStream in which each RDD is generated by applying a function - * on each RDD of this DStream. - */ - def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - new TransformedDStream(this, context.sparkContext.clean(transformFunc)) - } - - /** - * Print the first ten elements of each RDD generated in this DStream. This is an output - * operator, so this DStream will be registered as an output stream and there materialized. - */ - def print() { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val first11 = rdd.take(11) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - first11.take(10).foreach(println) - if (first11.size > 10) println("...") - println() - } - val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) - ssc.registerOutputStream(newStream) - } - - /** - * Return a new DStream in which each RDD contains all the elements in seen in a - * sliding window of time over this DStream. The new DStream generates RDDs with - * the same interval as this DStream. - * @param windowDuration width of the window; must be a multiple of this DStream's interval. - */ - def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) - - /** - * Return a new DStream in which each RDD contains all the elements in seen in a - * sliding window of time over this DStream. 
- * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { - new WindowedDStream(this, windowDuration, slideDuration) - } - - /** - * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByWindow( - reduceFunc: (T, T) => T, - windowDuration: Duration, - slideDuration: Duration - ): DStream[T] = { - this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) - } - - /** - * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a sliding window over this DStream. However, the reduction is done incrementally - * using the old window's reduced value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient than reduceByWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByWindow( - reduceFunc: (T, T) => T, - invReduceFunc: (T, T) => T, - windowDuration: Duration, - slideDuration: Duration - ): DStream[T] = { - this.map(x => (1, x)) - .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) - .map(_._2) - } - - /** - * Return a new DStream in which each RDD has a single element generated by counting the number - * of elements in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with - * Spark's default number of partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { - this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) - } - - /** - * Return a new DStream in which each RDD contains the count of distinct elements in - * RDDs in a sliding window over this DStream. Hash partitioning is used to generate - * the RDDs with `numPartitions` partitions (Spark's default number of partitions if - * `numPartitions` not specified). 
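The plain window(), reduceByWindow() and countByWindow() operators above compose as in the following sketch. It assumes the `ssc` and `lines` values from the earlier word-count sketch (1 s batches), and note that the inverse-reduce variant keeps intermediate state, so a checkpoint directory (name assumed here) has to be set.

    ssc.checkpoint("checkpoints")                            // needed by the incremental window path; path assumed

    val last30s = lines.window(Seconds(30), Seconds(10))     // every 10 s, an RDD with the last 30 s of lines

    // Incremental element count over the same window: the newest batch is added in,
    // the batch that just left the window is subtracted back out.
    val sizes = lines.map(_ => 1L).reduceByWindow(
      (a: Long, b: Long) => a + b,
      (a: Long, b: Long) => a - b,
      Seconds(30), Seconds(10))

    val sizes2 = lines.countByWindow(Seconds(30), Seconds(10))   // the same result via the convenience method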
- * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions number of partitions of each RDD in the new DStream. - */ - def countByValueAndWindow( - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int = ssc.sc.defaultParallelism - ): DStream[(T, Long)] = { - - this.map(x => (x, 1L)).reduceByKeyAndWindow( - (x: Long, y: Long) => x + y, - (x: Long, y: Long) => x - y, - windowDuration, - slideDuration, - numPartitions, - (x: (T, Long)) => x._2 != 0L - ) - } - - /** - * Return a new DStream by unifying data of another DStream with this DStream. - * @param that Another DStream having the same slideDuration as this DStream. - */ - def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) - - /** - * Return all the RDDs defined by the Interval object (both end times included) - */ - protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { - slice(interval.beginTime, interval.endTime) - } - - /** - * Return all the RDDs between 'fromTime' to 'toTime' (both included) - */ - def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) { - logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") - } - if (!(toTime - zeroTime).isMultipleOf(slideDuration)) { - logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") - } - val alignedToTime = toTime.floor(slideDuration) - val alignedFromTime = fromTime.floor(slideDuration) - - logInfo("Slicing from " + fromTime + " to " + toTime + - " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") - - alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { - if (time >= zeroTime) getOrCompute(time) else None - }) - } - - /** - * Save each RDD in this DStream as a Sequence file of serialized objects. - * The file name at each batch interval is generated based on `prefix` and - * `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsObjectFiles(prefix: String, suffix: String = "") { - val saveFunc = (rdd: RDD[T], time: Time) => { - val file = rddToFileName(prefix, suffix, time) - rdd.saveAsObjectFile(file) - } - this.foreach(saveFunc) - } - - /** - * Save each RDD in this DStream as at text file, using string representation - * of elements. The file name at each batch interval is generated based on - * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsTextFiles(prefix: String, suffix: String = "") { - val saveFunc = (rdd: RDD[T], time: Time) => { - val file = rddToFileName(prefix, suffix, time) - rdd.saveAsTextFile(file) - } - this.foreach(saveFunc) - } - - def register() { - ssc.registerOutputStream(this) - } -} diff --git a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala deleted file mode 100644 index 399ca1c63d..0000000000 --- a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
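For the saveAsObjectFiles()/saveAsTextFiles() members that close out DStream.scala above, a short sketch; it assumes `lines` as before, and the HDFS paths are placeholders.

    // One output directory is written per batch, named "<prefix>-<time in ms>.<suffix>".
    val counts = lines.flatMap(_.split(" ")).countByValue()
    counts.saveAsTextFiles("hdfs://namenode:8020/streams/wordcounts", "txt")    // placeholder path
    counts.saveAsObjectFiles("hdfs://namenode:8020/streams/wordcounts-obj")     // placeholder path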
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.conf.Configuration -import collection.mutable.HashMap -import spark.Logging - - - -private[streaming] -class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) - extends Serializable with Logging { - protected val data = new HashMap[Time, AnyRef]() - - @transient private var fileSystem : FileSystem = null - @transient private var lastCheckpointFiles: HashMap[Time, String] = null - - protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] - - /** - * Updates the checkpoint data of the DStream. This gets called every time - * the graph checkpoint is initiated. Default implementation records the - * checkpoint files to which the generate RDDs of the DStream has been saved. - */ - def update() { - - // Get the checkpointed RDDs from the generated RDDs - val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) - .map(x => (x._1, x._2.getCheckpointFile.get)) - - // Make a copy of the existing checkpoint data (checkpointed RDDs) - lastCheckpointFiles = checkpointFiles.clone() - - // If the new checkpoint data has checkpoints then replace existing with the new one - if (newCheckpointFiles.size > 0) { - checkpointFiles.clear() - checkpointFiles ++= newCheckpointFiles - } - - // TODO: remove this, this is just for debugging - newCheckpointFiles.foreach { - case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } - } - } - - /** - * Cleanup old checkpoint data. This gets called every time the graph - * checkpoint is initiated, but after `update` is called. Default - * implementation, cleans up old checkpoint files. - */ - def cleanup() { - // If there is at least on checkpoint file in the current checkpoint files, - // then delete the old checkpoint files. - if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { - (lastCheckpointFiles -- checkpointFiles.keySet).foreach { - case (time, file) => { - try { - val path = new Path(file) - if (fileSystem == null) { - fileSystem = path.getFileSystem(new Configuration()) - } - fileSystem.delete(path, true) - logInfo("Deleted checkpoint file '" + file + "' for time " + time) - } catch { - case e: Exception => - logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) - } - } - } - } - } - - /** - * Restore the checkpoint data. This gets called once when the DStream graph - * (along with its DStreams) are being restored from a graph checkpoint file. - * Default implementation restores the RDDs from their checkpoint files. 
- */ - def restore() { - // Create RDDs from the checkpoint data - checkpointFiles.foreach { - case(time, file) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") - dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) - } - } - } - - override def toString() = { - "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" - } -} - diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala deleted file mode 100644 index c09a332d44..0000000000 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import dstream.InputDStream -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import collection.mutable.ArrayBuffer -import spark.Logging - -final private[streaming] class DStreamGraph extends Serializable with Logging { - initLogging() - - private val inputStreams = new ArrayBuffer[InputDStream[_]]() - private val outputStreams = new ArrayBuffer[DStream[_]]() - - var rememberDuration: Duration = null - var checkpointInProgress = false - - var zeroTime: Time = null - var startTime: Time = null - var batchDuration: Duration = null - - def start(time: Time) { - this.synchronized { - if (zeroTime != null) { - throw new Exception("DStream graph computation already started") - } - zeroTime = time - startTime = time - outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.remember(rememberDuration)) - outputStreams.foreach(_.validate) - inputStreams.par.foreach(_.start()) - } - } - - def restart(time: Time) { - this.synchronized { startTime = time } - } - - def stop() { - this.synchronized { - inputStreams.par.foreach(_.stop()) - } - } - - def setContext(ssc: StreamingContext) { - this.synchronized { - outputStreams.foreach(_.setContext(ssc)) - } - } - - def setBatchDuration(duration: Duration) { - this.synchronized { - if (batchDuration != null) { - throw new Exception("Batch duration already set as " + batchDuration + - ". cannot set it again.") - } - batchDuration = duration - } - } - - def remember(duration: Duration) { - this.synchronized { - if (rememberDuration != null) { - throw new Exception("Batch duration already set as " + batchDuration + - ". 
cannot set it again.") - } - rememberDuration = duration - } - } - - def addInputStream(inputStream: InputDStream[_]) { - this.synchronized { - inputStream.setGraph(this) - inputStreams += inputStream - } - } - - def addOutputStream(outputStream: DStream[_]) { - this.synchronized { - outputStream.setGraph(this) - outputStreams += outputStream - } - } - - def getInputStreams() = this.synchronized { inputStreams.toArray } - - def getOutputStreams() = this.synchronized { outputStreams.toArray } - - def generateJobs(time: Time): Seq[Job] = { - this.synchronized { - logInfo("Generating jobs for time " + time) - val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) - logInfo("Generated " + jobs.length + " jobs for time " + time) - jobs - } - } - - def clearOldMetadata(time: Time) { - this.synchronized { - logInfo("Clearing old metadata for time " + time) - outputStreams.foreach(_.clearOldMetadata(time)) - logInfo("Cleared old metadata for time " + time) - } - } - - def updateCheckpointData(time: Time) { - this.synchronized { - logInfo("Updating checkpoint data for time " + time) - outputStreams.foreach(_.updateCheckpointData(time)) - logInfo("Updated checkpoint data for time " + time) - } - } - - def restoreCheckpointData() { - this.synchronized { - logInfo("Restoring checkpoint data") - outputStreams.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") - } - } - - def validate() { - this.synchronized { - assert(batchDuration != null, "Batch duration has not been set") - //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") - assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") - } - } - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - this.synchronized { - logDebug("DStreamGraph.writeObject used") - checkpointInProgress = true - oos.defaultWriteObject() - checkpointInProgress = false - } - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - this.synchronized { - logDebug("DStreamGraph.readObject used") - checkpointInProgress = true - ois.defaultReadObject() - checkpointInProgress = false - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala deleted file mode 100644 index 12a14e233d..0000000000 --- a/streaming/src/main/scala/spark/streaming/Duration.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming - -import spark.Utils - -case class Duration (private val millis: Long) { - - def < (that: Duration): Boolean = (this.millis < that.millis) - - def <= (that: Duration): Boolean = (this.millis <= that.millis) - - def > (that: Duration): Boolean = (this.millis > that.millis) - - def >= (that: Duration): Boolean = (this.millis >= that.millis) - - def + (that: Duration): Duration = new Duration(millis + that.millis) - - def - (that: Duration): Duration = new Duration(millis - that.millis) - - def * (times: Int): Duration = new Duration(millis * times) - - def / (that: Duration): Double = millis.toDouble / that.millis.toDouble - - def isMultipleOf(that: Duration): Boolean = - (this.millis % that.millis == 0) - - def min(that: Duration): Duration = if (this < that) this else that - - def max(that: Duration): Duration = if (this > that) this else that - - def isZero: Boolean = (this.millis == 0) - - override def toString: String = (millis.toString + " ms") - - def toFormattedString: String = millis.toString - - def milliseconds: Long = millis - - def prettyPrint = Utils.msDurationToString(millis) - -} - -/** - * Helper object that creates instance of [[spark.streaming.Duration]] representing - * a given number of milliseconds. - */ -object Milliseconds { - def apply(milliseconds: Long) = new Duration(milliseconds) -} - -/** - * Helper object that creates instance of [[spark.streaming.Duration]] representing - * a given number of seconds. - */ -object Seconds { - def apply(seconds: Long) = new Duration(seconds * 1000) -} - -/** - * Helper object that creates instance of [[spark.streaming.Duration]] representing - * a given number of minutes. - */ -object Minutes { - def apply(minutes: Long) = new Duration(minutes * 60000) -} - - diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala deleted file mode 100644 index b30cd969e9..0000000000 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
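Because every windowing signature in this patch is expressed in terms of Duration, a few lines showing how the helper objects above combine may help; the values are arbitrary, and pasted into the Spark shell the expressions evaluate as commented.

    import spark.streaming.{Milliseconds, Seconds, Minutes}

    val batch  = Seconds(2)
    val window = Minutes(1)

    window.isMultipleOf(batch)          // true (60000 % 2000 == 0): the invariant the window operators rely on
    val slide = window - Seconds(50)    // Duration of 10000 ms
    batch * 5                           // Duration of 10000 ms
    Seconds(90).min(Minutes(2))         // 90000 ms
    Seconds(90) / batch                 // 45.0
    batch.prettyPrint                   // human-readable string via spark.Utils.msDurationToString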
- */ - -package spark.streaming - -private[streaming] -class Interval(val beginTime: Time, val endTime: Time) { - def this(beginMs: Long, endMs: Long) = this(new Time(beginMs), new Time(endMs)) - - def duration(): Duration = endTime - beginTime - - def + (time: Duration): Interval = { - new Interval(beginTime + time, endTime + time) - } - - def - (time: Duration): Interval = { - new Interval(beginTime - time, endTime - time) - } - - def < (that: Interval): Boolean = { - if (this.duration != that.duration) { - throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") - } - this.endTime < that.endTime - } - - def <= (that: Interval) = (this < that || this == that) - - def > (that: Interval) = !(this <= that) - - def >= (that: Interval) = !(this < that) - - override def toString = "[" + beginTime + ", " + endTime + "]" -} - -private[streaming] -object Interval { - def currentInterval(duration: Duration): Interval = { - val time = new Time(System.currentTimeMillis) - val intervalBegin = time.floor(duration) - new Interval(intervalBegin, intervalBegin + duration) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala deleted file mode 100644 index ceb3f92b65..0000000000 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import java.util.concurrent.atomic.AtomicLong - -private[streaming] -class Job(val time: Time, func: () => _) { - val id = Job.getNewId() - def run(): Long = { - val startTime = System.currentTimeMillis - func() - val stopTime = System.currentTimeMillis - (stopTime - startTime) - } - - override def toString = "streaming job " + id + " @ " + time -} - -private[streaming] -object Job { - val id = new AtomicLong(0) - - def getNewId() = id.getAndIncrement() -} - diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala deleted file mode 100644 index a31230689f..0000000000 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import spark.Logging -import spark.SparkEnv -import java.util.concurrent.Executors -import collection.mutable.HashMap -import collection.mutable.ArrayBuffer - - -private[streaming] -class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { - - class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { - def run() { - SparkEnv.set(ssc.env) - try { - val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format( - (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0)) - } catch { - case e: Exception => - logError("Running " + job + " failed", e) - } - clearJob(job) - } - } - - initLogging() - - val jobExecutor = Executors.newFixedThreadPool(numThreads) - val jobs = new HashMap[Time, ArrayBuffer[Job]] - - def runJob(job: Job) { - jobs.synchronized { - jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job - } - jobExecutor.execute(new JobHandler(ssc, job)) - logInfo("Added " + job + " to queue") - } - - def stop() { - jobExecutor.shutdown() - } - - private def clearJob(job: Job) { - var timeCleared = false - val time = job.time - jobs.synchronized { - val jobsOfTime = jobs.get(time) - if (jobsOfTime.isDefined) { - jobsOfTime.get -= job - if (jobsOfTime.get.isEmpty) { - jobs -= time - timeCleared = true - } - } else { - throw new Exception("Job finished for time " + job.time + - " but time does not exist in jobs") - } - } - if (timeCleared) { - ssc.scheduler.clearOldMetadata(time) - } - } - - def getPendingTimes(): Array[Time] = { - jobs.synchronized { - jobs.keySet.toArray - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala deleted file mode 100644 index d4cf2e568c..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming - -import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} -import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import spark.Logging -import spark.SparkEnv -import spark.SparkContext._ - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue - -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - -private[streaming] sealed trait NetworkInputTrackerMessage -private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage -private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage - -/** - * This class manages the execution of the receivers of NetworkInputDStreams. - */ -private[streaming] -class NetworkInputTracker( - @transient ssc: StreamingContext, - @transient networkInputStreams: Array[NetworkInputDStream[_]]) - extends Logging { - - val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverExecutor() - val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[String]] - val timeout = 5000.milliseconds - - var currentTime: Time = null - - /** Start the actor and receiver execution thread. */ - def start() { - ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") - receiverExecutor.start() - } - - /** Stop the receiver execution thread. */ - def stop() { - // TODO: stop the actor as well - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() - } - - /** Return all the blocks received from a receiver. */ - def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { - val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[String]()) - } - val result = queue.synchronized { - queue.dequeueAll(x => true) - } - logInfo("Stream " + receiverId + " received " + result.size + " blocks") - result.toArray - } - - /** Actor to receive messages from the receivers. */ - private class NetworkInputTrackerActor extends Actor { - def receive = { - case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) - } - receiverInfo += ((streamId, receiverActor)) - logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) - sender ! true - } - case AddBlocks(streamId, blockIds, metadata) => { - val tmp = receivedBlockIds.synchronized { - if (!receivedBlockIds.contains(streamId)) { - receivedBlockIds += ((streamId, new Queue[String])) - } - receivedBlockIds(streamId) - } - tmp.synchronized { - tmp ++= blockIds - } - networkInputStreamMap(streamId).addMetadata(metadata) - } - case DeregisterReceiver(streamId, msg) => { - receiverInfo -= streamId - logError("De-registered receiver for network stream " + streamId - + " with message " + msg) - //TODO: Do something about the corresponding NetworkInputDStream - } - } - } - - /** This thread class runs all the receivers on the cluster. 
*/ - class ReceiverExecutor extends Thread { - val env = ssc.env - - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") - } finally { - stopReceivers() - } - } - - /** - * Get the receivers from the NetworkInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. - */ - def startReceivers() { - val receivers = networkInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setStreamId(nis.id) - rcvr - }) - - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) - ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) - } - else { - ssc.sc.makeRDD(receivers, receivers.size) - } - - // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { - if (!iterator.hasNext) { - throw new Exception("Could not start receiver as details not found.") - } - iterator.next().start() - } - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - - // Distribute the receivers and start them - ssc.sparkContext.runJob(tempRDD, startReceiver) - } - - /** Stops the receivers. */ - def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.foreach(_ ! StopReceiver) - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala deleted file mode 100644 index 47bf07bee1..0000000000 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ /dev/null @@ -1,534 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
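The one non-obvious step in startReceivers() above is the throwaway job that runs before the receivers are distributed. Stripped of its surroundings, it is just the following (the partition counts 50 and 20 simply mirror the values used above; `ssc` is assumed):

    import spark.SparkContext._   // for reduceByKey on the pair RDD

    // Touch many partitions once so that every worker registers with the master first;
    // only then are the receiver tasks submitted, so they spread across the cluster
    // instead of all landing on whichever node happened to connect first.
    ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()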
- */ - -package spark.streaming - -import spark.streaming.StreamingContext._ -import spark.streaming.dstream.{ReducedWindowedDStream, StateDStream} -import spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream} -import spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream} - -import spark.{Manifests, RDD, Partitioner, HashPartitioner} -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.mapred.{JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.conf.Configuration - -class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) -extends Serializable { - - private[streaming] def ssc = self.ssc - - private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { - new HashPartitioner(numPartitions) - } - - /** - * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to - * generate the RDDs with Spark's default number of partitions. - */ - def groupByKey(): DStream[(K, Seq[V])] = { - groupByKey(defaultPartitioner()) - } - - /** - * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to - * generate the RDDs with `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { - groupByKey(defaultPartitioner(numPartitions)) - } - - /** - * Return a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]] - * is used to control the partitioning of each RDD. - */ - def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { - val createCombiner = (v: V) => ArrayBuffer[V](v) - val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) - val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) - combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner) - .asInstanceOf[DStream[(K, Seq[V])]] - } - - /** - * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the associative reduce function. Hash partitioning is used to generate the RDDs - * with Spark's default number of partitions. - */ - def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { - reduceByKey(reduceFunc, defaultPartitioner()) - } - - /** - * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs - * with `numPartitions` partitions. - */ - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { - reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) - } - - /** - * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the - * partitioning of each RDD. - */ - def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) - } - - /** - * Combine elements of each key in DStream's RDDs using custom functions. This is similar to the - * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more - * information. 
- */ - def combineByKey[C: ClassManifest]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner) : DStream[(K, C)] = { - new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to - * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs - * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with - * Spark's default number of partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - */ - def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = { - groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window. Similar to - * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to - * generate the RDDs with Spark's default number of partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = { - groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * Similar to `DStream.groupByKey()`, but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions number of partitions of each RDD in the new DStream; if not specified - * then Spark's default number of partitions will be used - */ - def groupByKeyAndWindow( - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int - ): DStream[(K, Seq[V])] = { - groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) - } - - /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * Similar to `DStream.groupByKey()`, but applies it over a sliding window. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. - */ - def groupByKeyAndWindow( - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner - ): DStream[(K, Seq[V])] = { - self.window(windowDuration, slideDuration).groupByKey(partitioner) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. 
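The generic combineByKey and the windowed grouping above can be sketched together; the "station,reading" input format, the port, and the partition count are assumptions, while the operator signatures are the ones defined in this file.

    import spark.HashPartitioner

    // Assumes `ssc` and the imports from the earlier sketches.
    val temps = ssc.socketTextStream("localhost", 9998)      // assumed "station,reading" source
      .map(_.split(","))
      .filter(_.length == 2)
      .map(a => (a(0), a(1).toDouble))

    // Per-batch running average per station, built with the generic combiner
    val avgPerStation = temps.combineByKey[(Double, Int)](
      (v: Double) => (v, 1),
      (acc: (Double, Int), v: Double) => (acc._1 + v, acc._2 + 1),
      (a: (Double, Int), b: (Double, Int)) => (a._1 + b._1, a._2 + b._2),
      new HashPartitioner(2)
    ).mapValues(acc => acc._1 / acc._2)

    // All readings per station over the last 30 s, refreshed every 10 s
    val recent = temps.groupByKeyAndWindow(Seconds(30), Seconds(10))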
The new DStream - * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate - * the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowDuration: Duration - ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to - * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to - * generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration - ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to - * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to - * generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions number of partitions of each RDD in the new DStream. - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int - ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to - * `DStream.reduceByKey()`, but applies it over a sliding window. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner partitioner for controlling the partitioning of each RDD - * in the new DStream. - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner - ): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - self.reduceByKey(cleanedReduceFunc, partitioner) - .window(windowDuration, slideDuration) - .reduceByKey(cleanedReduceFunc, partitioner) - } - - /** - * Return a new DStream by applying incremental `reduceByKey` over a sliding window. - * The reduced value of over a new window is calculated using the old window's reduced value : - * 1. 
reduce the new values that entered the window (e.g., adding new counts) - * - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * - * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param filterFunc Optional function to filter expired key-value pairs; - * only pairs that satisfy the function are retained - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration = self.slideDuration, - numPartitions: Int = ssc.sc.defaultParallelism, - filterFunc: ((K, V)) => Boolean = null - ): DStream[(K, V)] = { - - reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowDuration, - slideDuration, defaultPartitioner(numPartitions), filterFunc - ) - } - - /** - * Return a new DStream by applying incremental `reduceByKey` over a sliding window. - * The reduced value of over a new window is calculated using the old window's reduced value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. - * @param filterFunc Optional function to filter expired key-value pairs; - * only pairs that satisfy the function are retained - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner, - filterFunc: ((K, V)) => Boolean - ): DStream[(K, V)] = { - - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) - val cleanedFilterFunc = if (filterFunc != null) Some(ssc.sc.clean(filterFunc)) else None - new ReducedWindowedDStream[K, V]( - self, cleanedReduceFunc, cleanedInvReduceFunc, cleanedFilterFunc, - windowDuration, slideDuration, partitioner - ) - } - - /** - * Return a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of each key. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. 
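The incremental reduceByKeyAndWindow described above is the workhorse for long windows. A sketch with every argument passed positionally follows; the window and slide values, the partition count, and the checkpoint directory name are arbitrary, and `ssc`/`lines` are assumed from the earlier sketches.

    ssc.checkpoint("checkpoints")                      // this variant keeps state across batches; path assumed

    val pairs = lines.flatMap(_.split(" ")).map(w => (w, 1L))

    // Counts over a 10-minute window, updated every 10 s: the newest batch is added,
    // the batch that just expired is subtracted, and keys whose count reaches zero are dropped.
    val windowedCounts = pairs.reduceByKeyAndWindow(
      (a: Long, b: Long) => a + b,                     // associative reduce
      (a: Long, b: Long) => a - b,                     // its inverse
      Minutes(10),                                     // window width
      Seconds(10),                                     // slide interval
      2,                                               // partitions per generated RDD
      (kv: (String, Long)) => kv._2 != 0L              // filter out exhausted keys
    )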
- * @tparam S State type - */ - def updateStateByKey[S: ClassManifest]( - updateFunc: (Seq[V], Option[S]) => Option[S] - ): DStream[(K, S)] = { - updateStateByKey(updateFunc, defaultPartitioner()) - } - - /** - * Return a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of each key. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. - * @param numPartitions Number of partitions of each RDD in the new DStream. - * @tparam S State type - */ - def updateStateByKey[S: ClassManifest]( - updateFunc: (Seq[V], Option[S]) => Option[S], - numPartitions: Int - ): DStream[(K, S)] = { - updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) - } - - /** - * Create a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of the key. - * [[spark.Partitioner]] is used to control the partitioning of each RDD. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - * @tparam S State type - */ - def updateStateByKey[S: ClassManifest]( - updateFunc: (Seq[V], Option[S]) => Option[S], - partitioner: Partitioner - ): DStream[(K, S)] = { - val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - updateStateByKey(newUpdateFunc, partitioner, true) - } - - /** - * Return a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of each key. - * [[spark.Paxrtitioner]] is used to control the partitioning of each RDD. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. Note, that - * this function may generate a different a tuple with a different key - * than the input key. It is up to the developer to decide whether to - * remember the partitioner despite the key being changed. - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. - * @tparam S State type - */ - def updateStateByKey[S: ClassManifest]( - updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], - partitioner: Partitioner, - rememberPartitioner: Boolean - ): DStream[(K, S)] = { - new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner) - } - - - def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuedDStream[K, V, U](self, mapValuesFunc) - } - - def flatMapValues[U: ClassManifest]( - flatMapValuesFunc: V => TraversableOnce[U] - ): DStream[(K, U)] = { - new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) - } - - /** - * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` - * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that - * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number - * of partitions. 
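updateStateByKey above is the per-key state API; the classic running-count sketch looks like this (checkpoint directory assumed, `ssc` and `lines` as before).

    ssc.checkpoint("checkpoints")                      // state DStreams must be checkpointed; path assumed

    val wordPairs = lines.flatMap(_.split(" ")).map(w => (w, 1))

    // Running count per word across all batches; returning None here would discard the key's state.
    val updateFunc = (newValues: Seq[Int], runningCount: Option[Int]) =>
      Some(newValues.sum + runningCount.getOrElse(0))

    val runningCounts = wordPairs.updateStateByKey[Int](updateFunc)
    runningCounts.print()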
- */ - def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner()) - } - - /** - * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this` - * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that - * key in both RDDs. Partitioner is used to partition each generated RDD. - */ - def cogroup[W: ClassManifest]( - other: DStream[(K, W)], - partitioner: Partitioner - ): DStream[(K, (Seq[V], Seq[W]))] = { - - val cgd = new CoGroupedDStream[K]( - Seq(self.asInstanceOf[DStream[(K, _)]], other.asInstanceOf[DStream[(K, _)]]), - partitioner - ) - val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)( - classManifest[K], - Manifests.seqSeqManifest - ) - pdfs.mapValues { - case Seq(vs, ws) => - (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) - } - } - - /** - * Join `this` DStream with `other` DStream. HashPartitioner is used - * to partition each generated RDD into default number of partitions. - */ - def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { - join[W](other, defaultPartitioner()) - } - - /** - * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will - * be generated by joining RDDs from `this` and other DStream. Uses the given - * Partitioner to partition each generated RDD. - */ - def join[W: ClassManifest]( - other: DStream[(K, W)], - partitioner: Partitioner - ): DStream[(K, (V, W))] = { - this.cogroup(other, partitioner) - .flatMapValues{ - case (vs, ws) => - for (v <- vs.iterator; w <- ws.iterator) yield (v, w) - } - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" - */ - def saveAsHadoopFiles[F <: OutputFormat[K, V]]( - prefix: String, - suffix: String - )(implicit fm: ClassManifest[F]) { - saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" - */ - def saveAsHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf - ) { - val saveFunc = (rdd: RDD[(K, V)], time: Time) => { - val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) - } - self.foreach(saveFunc) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( - prefix: String, - suffix: String - )(implicit fm: ClassManifest[F]) { - saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". 
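cogroup and join above operate batch by batch on two keyed streams with the same slide interval; the two socket sources below are assumptions used only to give the sketch something to join.

    val clicks      = ssc.socketTextStream("localhost", 9999).map(url => (url, 1)).reduceByKey(_ + _)
    val impressions = ssc.socketTextStream("localhost", 9998).map(url => (url, 1)).reduceByKey(_ + _)

    // Per batch: keys present on both sides, with one value from each
    val joined  = clicks.join(impressions)        // DStream[(String, (Int, Int))]

    // Per batch: every key from either side, with all of its values from both
    val grouped = clicks.cogroup(impressions)     // DStream[(String, (Seq[Int], Seq[Int]))]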
- */ - def saveAsNewAPIHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = new Configuration - ) { - val saveFunc = (rdd: RDD[(K, V)], time: Time) => { - val file = rddToFileName(prefix, suffix, time) - rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) - } - self.foreach(saveFunc) - } - - private def getKeyClass() = implicitly[ClassManifest[K]].erasure - - private def getValueClass() = implicitly[ClassManifest[V]].erasure -} - - diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala deleted file mode 100644 index 252cc2a303..0000000000 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import util.{ManualClock, RecurringTimer, Clock} -import spark.SparkEnv -import spark.Logging - -private[streaming] -class Scheduler(ssc: StreamingContext) extends Logging { - - initLogging() - - val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt - val jobManager = new JobManager(ssc, concurrentJobs) - val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.checkpointDir) - } else { - null - } - - val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") - val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] - val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => generateJobs(new Time(longTime))) - val graph = ssc.graph - var latestTime: Time = null - - def start() = synchronized { - if (ssc.isCheckpointPresent) { - restart() - } else { - startFirstTime() - } - logInfo("Scheduler started") - } - - def stop() = synchronized { - timer.stop() - jobManager.stop() - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("Scheduler stopped") - } - - private def startFirstTime() { - val startTime = new Time(timer.getStartTime()) - graph.start(startTime - graph.batchDuration) - timer.start(startTime.milliseconds) - logInfo("Scheduler's timer started at " + startTime) - } - - private def restart() { - - // If manual clock is being used for testing, then - // either set the manual clock to the last checkpointed time, - // or if the property is defined set it to that time - if (clock.isInstanceOf[ManualClock]) { - val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds - val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong - clock.asInstanceOf[ManualClock].setTime(lastTime + 
jumpTime) - } - - val batchDuration = ssc.graph.batchDuration - - // Batches when the master was down, that is, - // between the checkpoint and current restart time - val checkpointTime = ssc.initialCheckpoint.checkpointTime - val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) - val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time: " + downTimes.mkString(", ")) - - // Batches that were unprocessed before failure - val pendingTimes = ssc.initialCheckpoint.pendingTimes - logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) - // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) - timesToReschedule.foreach(time => - graph.generateJobs(time).foreach(jobManager.runJob) - ) - - // Restart the timer - timer.start(restartTime.milliseconds) - logInfo("Scheduler's timer restarted at " + restartTime) - } - - /** Generate jobs and perform checkpoint for the given `time`. */ - def generateJobs(time: Time) { - SparkEnv.set(ssc.env) - logInfo("\n-----------------------------------------------------\n") - graph.generateJobs(time).foreach(jobManager.runJob) - latestTime = time - doCheckpoint(time) - } - - /** - * Clear old metadata assuming jobs of `time` have finished processing. - * And also perform checkpoint. - */ - def clearOldMetadata(time: Time) { - ssc.graph.clearOldMetadata(time) - doCheckpoint(time) - } - - /** Perform checkpoint for the give `time`. */ - def doCheckpoint(time: Time) = synchronized { - if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { - logInfo("Checkpointing graph for time " + time) - ssc.graph.updateCheckpointData(time) - checkpointWriter.write(new Checkpoint(ssc, time)) - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala deleted file mode 100644 index 62c95b573a..0000000000 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ /dev/null @@ -1,563 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
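The Scheduler above is configured entirely through system properties, which have to be set before the StreamingContext (and hence the Scheduler) is created. The property names come straight from the code above; the values below are only examples.

    // Run up to two generated jobs concurrently in the JobManager's thread pool
    System.setProperty("spark.streaming.concurrentJobs", "2")

    // Used by the test suites: drive batches with a manual clock instead of wall-clock time
    System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")

    // On recovery with the manual clock, jump it this many ms past the checkpointed time
    System.setProperty("spark.streaming.manualClock.jump", "1000")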
- */ - -package spark.streaming - -import akka.actor.Props -import akka.actor.SupervisorStrategy -import akka.zeromq.Subscribe - -import spark.streaming.dstream._ - -import spark._ -import spark.streaming.receivers.ActorReceiver -import spark.streaming.receivers.ReceiverSupervisorStrategy -import spark.streaming.receivers.ZeroMQReceiver -import spark.storage.StorageLevel -import spark.util.MetadataCleaner -import spark.streaming.receivers.ActorReceiver - -import scala.collection.mutable.Queue -import scala.collection.Map - -import java.io.InputStream -import java.util.concurrent.atomic.AtomicInteger -import java.util.UUID - -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -import org.apache.hadoop.fs.Path -import twitter4j.Status -import twitter4j.auth.Authorization - - -/** - * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic - * information (such as, cluster URL and job name) to internally create a SparkContext, it provides - * methods used to create DStream from various input sources. - */ -class StreamingContext private ( - sc_ : SparkContext, - cp_ : Checkpoint, - batchDur_ : Duration - ) extends Logging { - - /** - * Create a StreamingContext using an existing SparkContext. - * @param sparkContext Existing SparkContext - * @param batchDuration The time interval at which streaming data will be divided into batches - */ - def this(sparkContext: SparkContext, batchDuration: Duration) = { - this(sparkContext, null, batchDuration) - } - - /** - * Create a StreamingContext by providing the details necessary for creating a new SparkContext. - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your job, to display on the cluster web UI - * @param batchDuration The time interval at which streaming data will be divided into batches - */ - def this( - master: String, - appName: String, - batchDuration: Duration, - sparkHome: String = null, - jars: Seq[String] = Nil, - environment: Map[String, String] = Map()) = { - this(StreamingContext.createNewSparkContext(master, appName, sparkHome, jars, environment), - null, batchDuration) - } - - - /** - * Re-create a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. 
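As a usage note for the constructors above, a streaming application either builds a fresh context from a cluster URL and app name or recovers one from a checkpoint directory. The master URL, app name and HDFS path below are placeholders, and `Seconds` is assumed to come from Duration.scala.

    import spark.streaming.{Seconds, StreamingContext}

    // Fresh context; incoming data is cut into 1-second batches.
    val ssc = new StreamingContext("local[2]", "ExampleApp", Seconds(1))

    // Or recover a context (and its DStream graph) from a checkpoint directory.
    val recovered = new StreamingContext("hdfs://namenode:8020/checkpoints/example")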
- */ - def this(path: String) = this(null, CheckpointReader.read(path), null) - - initLogging() - - if (sc_ == null && cp_ == null) { - throw new Exception("Spark Streaming cannot be initialized with " + - "both SparkContext and checkpoint as null") - } - - if (MetadataCleaner.getDelaySeconds < 0) { - throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " - + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") - } - - protected[streaming] val isCheckpointPresent = (cp_ != null) - - protected[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { - new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars, cp_.environment) - } else { - sc_ - } - } - - protected[streaming] val env = SparkEnv.get - - protected[streaming] val graph: DStreamGraph = { - if (isCheckpointPresent) { - cp_.graph.setContext(this) - cp_.graph.restoreCheckpointData() - cp_.graph - } else { - assert(batchDur_ != null, "Batch duration for streaming context cannot be null") - val newGraph = new DStreamGraph() - newGraph.setBatchDuration(batchDur_) - newGraph - } - } - - protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var networkInputTracker: NetworkInputTracker = null - - protected[streaming] var checkpointDir: String = { - if (isCheckpointPresent) { - sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true) - cp_.checkpointDir - } else { - null - } - } - - protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null - protected[streaming] var receiverJobThread: Thread = null - protected[streaming] var scheduler: Scheduler = null - - /** - * Return the associated Spark context - */ - def sparkContext = sc - - /** - * Set each DStreams in this context to remember RDDs it generated in the last given duration. - * DStreams remember RDDs only for a limited duration of time and releases them for garbage - * collection. This method allows the developer to specify how to long to remember the RDDs ( - * if the developer wishes to query old data outside the DStream computation). - * @param duration Minimum duration that each DStream should remember its RDDs - */ - def remember(duration: Duration) { - graph.remember(duration) - } - - /** - * Set the context to periodically checkpoint the DStream operations for master - * fault-tolerance. The graph will be checkpointed every batch interval. - * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored - */ - def checkpoint(directory: String) { - if (directory != null) { - sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) - checkpointDir = directory - } else { - checkpointDir = null - } - } - - protected[streaming] def initialCheckpoint: Checkpoint = { - if (isCheckpointPresent) cp_ else null - } - - protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() - - /** - * Create an input stream with any arbitrary user implemented network receiver. 
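A short sketch of the fault-tolerance knobs described above, assuming the hypothetical `ssc` context from the earlier sketch and the `Minutes` duration factory from Duration.scala; the HDFS path is a placeholder.

    // Enable periodic checkpointing of the DStream graph for master fault-tolerance.
    ssc.checkpoint("hdfs://namenode:8020/checkpoints/example")

    // Keep generated RDDs around for at least a minute, e.g. to query recent data
    // outside the DStream computation.
    ssc.remember(Minutes(1))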
- * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html - * @param receiver Custom implementation of NetworkReceiver - */ - def networkStream[T: ClassManifest]( - receiver: NetworkReceiver[T]): DStream[T] = { - val inputStream = new PluggableInputDStream[T](this, - receiver) - graph.addInputStream(inputStream) - inputStream - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel RDD storage level. Defaults to memory-only. - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T: ClassManifest]( - props: Props, - name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, - supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy - ): DStream[T] = { - networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic - * and each frame has sequence of byte thus it needs the converter - * (which might be deserializer of bytes) to translate from sequence - * of sequence of bytes, where sequence refer to a frame - * and sub sequence refer to its payload. - * @param storageLevel RDD storage level. Defaults to memory-only. - */ - def zeroMQStream[T: ClassManifest]( - publisherUrl:String, - subscribe: Subscribe, - bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T], - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, - supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy - ): DStream[T] = { - actorStream(Props(new ZeroMQReceiver(publisherUrl,subscribe,bytesToObjects)), - "ZeroMQReceiver", storageLevel, supervisorStrategy) - } - - /** - * Create an input stream that pulls messages from a Kafka Broker. - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param storageLevel Storage level to use for storing the received objects - * (default: StorageLevel.MEMORY_AND_DISK_SER_2) - */ - def kafkaStream( - zkQuorum: String, - groupId: String, - topics: Map[String, Int], - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 - ): DStream[String] = { - val kafkaParams = Map[String, String]( - "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000") - kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) - } - - /** - * Create an input stream that pulls messages from a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. - * See: http://kafka.apache.org/configuration.html - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. 
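A hedged example of the string-based kafkaStream overload above; the ZooKeeper quorum, consumer group id and topic map are placeholders, and `ssc` is the hypothetical context from earlier.

    // Two consumer threads for topic "events"; received strings default to MEMORY_ONLY_SER_2.
    val events = ssc.kafkaStream(
      "zk1:2181,zk2:2181",       // ZooKeeper quorum (hostname:port,hostname:port,..)
      "example-consumer-group",  // consumer group id
      Map("events" -> 2)         // topic -> number of consumer threads
    )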
- * @param storageLevel Storage level to use for storing the received objects - */ - def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( - kafkaParams: Map[String, String], - topics: Map[String, Int], - storageLevel: StorageLevel - ): DStream[T] = { - val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a input stream from TCP source hostname:port. Data is received using - * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited - * lines. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @param storageLevel Storage level to use for storing the received objects - * (default: StorageLevel.MEMORY_AND_DISK_SER_2) - */ - def socketTextStream( - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): DStream[String] = { - socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) - } - - /** - * Create a input stream from TCP source hostname:port. Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given - * converter. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @param converter Function to convert the byte stream to objects - * @param storageLevel Storage level to use for storing the received objects - * @tparam T Type of the objects received (after converting bytes to objects) - */ - def socketStream[T: ClassManifest]( - hostname: String, - port: Int, - converter: (InputStream) => Iterator[T], - storageLevel: StorageLevel - ): DStream[T] = { - val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a input stream from a Flume source. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - */ - def flumeStream ( - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): DStream[SparkFlumeEvent] = { - val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a input stream from network source hostname:port, where data is received - * as serialized blocks (serialized using the Spark's serializer) that can be directly - * pushed into the block manager without deserializing them. This is the most efficient - * way to receive data. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @param storageLevel Storage level to use for storing the received objects - * @tparam T Type of the objects in the received blocks - */ - def rawSocketStream[T: ClassManifest]( - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): DStream[T] = { - val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a input stream that monitors a Hadoop-compatible filesystem - * for new files and reads them using the given key-value types and input format. - * File names starting with . 
are ignored. - * @param directory HDFS directory to monitor for new file - * @tparam K Key type for reading HDFS file - * @tparam V Value type for reading HDFS file - * @tparam F Input format for reading HDFS file - */ - def fileStream[ - K: ClassManifest, - V: ClassManifest, - F <: NewInputFormat[K, V]: ClassManifest - ] (directory: String): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, directory) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a input stream that monitors a Hadoop-compatible filesystem - * for new files and reads them using the given key-value types and input format. - * @param directory HDFS directory to monitor for new file - * @param filter Function to filter paths to process - * @param newFilesOnly Should process only new files and ignore existing files in the directory - * @tparam K Key type for reading HDFS file - * @tparam V Value type for reading HDFS file - * @tparam F Input format for reading HDFS file - */ - def fileStream[ - K: ClassManifest, - V: ClassManifest, - F <: NewInputFormat[K, V]: ClassManifest - ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a input stream that monitors a Hadoop-compatible filesystem - * for new files and reads them as text files (using key as LongWritable, value - * as Text and input format as TextInputFormat). File names starting with . are ignored. - * @param directory HDFS directory to monitor for new file - */ - def textFileStream(directory: String): DStream[String] = { - fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) - } - - /** - * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth - * authorization; this uses the system properties twitter4j.oauth.consumerKey, - * .consumerSecret, .accessToken and .accessTokenSecret. - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def twitterStream( - twitterAuth: Option[Authorization] = None, - filters: Seq[String] = Nil, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): DStream[Status] = { - val inputStream = new TwitterInputDStream(this, twitterAuth, filters, storageLevel) - registerInputStream(inputStream) - inputStream - } - - /** - * Create an input stream from a queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue. - * @param queue Queue of RDDs - * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @tparam T Type of objects in the RDD - */ - def queueStream[T: ClassManifest]( - queue: Queue[RDD[T]], - oneAtATime: Boolean = true - ): DStream[T] = { - queueStream(queue, oneAtATime, sc.makeRDD(Seq[T](), 1)) - } - - /** - * Create an input stream from a queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue. - * @param queue Queue of RDDs - * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. 
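Two quick usage sketches for the input sources above: textFileStream for tailing an HDFS directory, and queueStream, which is convenient in tests because RDDs can be pushed into the queue by hand. The path and data are placeholders and `ssc` is assumed from the earlier sketch.

    import scala.collection.mutable.Queue
    import spark.RDD

    // Read new text files dropped into a monitored directory.
    val logs = ssc.textFileStream("hdfs://namenode:8020/incoming/logs")

    // Feed the stream from a queue of RDDs, one RDD per batch (oneAtATime defaults to true).
    val rddQueue = new Queue[RDD[Int]]()
    val numbers  = ssc.queueStream(rddQueue)
    rddQueue += ssc.sparkContext.makeRDD(1 to 100, 2)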
- * Set as null if no RDD should be returned when empty - * @tparam T Type of objects in the RDD - */ - def queueStream[T: ClassManifest]( - queue: Queue[RDD[T]], - oneAtATime: Boolean, - defaultRDD: RDD[T] - ): DStream[T] = { - val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - registerInputStream(inputStream) - inputStream - } - - /** - * Create a unified DStream from multiple DStreams of the same type and same interval - */ - def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = { - new UnionDStream[T](streams.toArray) - } - - /** - * Register an input stream that will be started (InputDStream.start() called) to get the - * input data. - */ - def registerInputStream(inputStream: InputDStream[_]) { - graph.addInputStream(inputStream) - } - - /** - * Register an output stream that will be computed every interval - */ - def registerOutputStream(outputStream: DStream[_]) { - graph.addOutputStream(outputStream) - } - - protected def validate() { - assert(graph != null, "Graph is null") - graph.validate() - - assert( - checkpointDir == null || checkpointDuration != null, - "Checkpoint directory has been set, but the graph checkpointing interval has " + - "not been set. Please use StreamingContext.checkpoint() to set the interval." - ) - } - - /** - * Start the execution of the streams. - */ - def start() { - if (checkpointDir != null && checkpointDuration == null && graph != null) { - checkpointDuration = graph.batchDuration - } - - validate() - - val networkInputStreams = graph.getInputStreams().filter(s => s match { - case n: NetworkInputDStream[_] => true - case _ => false - }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray - - if (networkInputStreams.length > 0) { - // Start the network input tracker (must start before receivers) - networkInputTracker = new NetworkInputTracker(this, networkInputStreams) - networkInputTracker.start() - } - - Thread.sleep(1000) - - // Start the scheduler - scheduler = new Scheduler(this) - scheduler.start() - } - - /** - * Stop the execution of the streams. - */ - def stop() { - try { - if (scheduler != null) scheduler.stop() - if (networkInputTracker != null) networkInputTracker.stop() - if (receiverJobThread != null) receiverJobThread.interrupt() - sc.stop() - logInfo("StreamingContext stopped successfully") - } catch { - case e: Exception => logWarning("Error while stopping", e) - } - } -} - - -object StreamingContext { - - implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { - new PairDStreamFunctions[K, V](stream) - } - - protected[streaming] def createNewSparkContext( - master: String, - appName: String, - sparkHome: String, - jars: Seq[String], - environment: Map[String, String]): SparkContext = { - // Set the default cleaner delay to an hour if not already set. - // This should be sufficient for even 1 second interval. - if (MetadataCleaner.getDelaySeconds < 0) { - MetadataCleaner.setDelaySeconds(3600) - } - new SparkContext(master, appName, sparkHome, jars, environment) - } - - protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { - if (prefix == null) { - time.milliseconds.toString - } else if (suffix == null || suffix.length ==0) { - prefix + "-" + time.milliseconds - } else { - prefix + "-" + time.milliseconds + "." 
+ suffix - } - } - - protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = { - new Path(sscCheckpointDir, UUID.randomUUID.toString).toString - } -} diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala deleted file mode 100644 index ad5eab9dd2..0000000000 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -/** - * This is a simple class that represents an absolute instant of time. - * Internally, it represents time as the difference, measured in milliseconds, between the current - * time and midnight, January 1, 1970 UTC. This is the same format as what is returned by - * System.currentTimeMillis. - */ -case class Time(private val millis: Long) { - - def milliseconds: Long = millis - - def < (that: Time): Boolean = (this.millis < that.millis) - - def <= (that: Time): Boolean = (this.millis <= that.millis) - - def > (that: Time): Boolean = (this.millis > that.millis) - - def >= (that: Time): Boolean = (this.millis >= that.millis) - - def + (that: Duration): Time = new Time(millis + that.milliseconds) - - def - (that: Time): Duration = new Duration(millis - that.millis) - - def - (that: Duration): Time = new Time(millis - that.milliseconds) - - def floor(that: Duration): Time = { - val t = that.milliseconds - val m = math.floor(this.millis / t).toLong - new Time(m * t) - } - - def isMultipleOf(that: Duration): Boolean = - (this.millis % that.milliseconds == 0) - - def min(that: Time): Time = if (this < that) this else that - - def max(that: Time): Time = if (this > that) this else that - - def until(that: Time, interval: Duration): Seq[Time] = { - (this.milliseconds) until (that.milliseconds) by (interval.milliseconds) map (new Time(_)) - } - - def to(that: Time, interval: Duration): Seq[Time] = { - (this.milliseconds) to (that.milliseconds) by (interval.milliseconds) map (new Time(_)) - } - - - override def toString: String = (millis.toString + " ms") - -} - -object Time { - val ordering = Ordering.by((time: Time) => time.millis) -} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala deleted file mode 100644 index 7dcb1d713d..0000000000 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
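A small illustration of the Time arithmetic defined above; the values are arbitrary and `Seconds` is assumed from Duration.scala.

    val t     = Time(12000)            // 12000 ms
    val later = t + Seconds(3)         // Time(15000)
    val gap   = later - t              // Duration of 3000 ms
    t.isMultipleOf(Seconds(4))         // true: 12000 % 4000 == 0
    t.floor(Seconds(5))                // Time(10000)
    Time(0).until(t, Seconds(4))       // Seq(Time(0), Time(4000), Time(8000))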
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.api.java - -import spark.streaming.{Duration, Time, DStream} -import spark.api.java.function.{Function => JFunction} -import spark.api.java.JavaRDD -import spark.storage.StorageLevel -import spark.RDD - -/** - * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous - * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] - * for more details on RDDs). DStreams can either be created from live data (such as, data from - * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations - * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each - * DStream periodically generates a RDD, either from live data or by transforming the RDD generated - * by a parent DStream. - * - * This class contains the basic operations available on all DStreams, such as `map`, `filter` and - * `window`. In addition, [[spark.streaming.api.java.JavaPairDStream]] contains operations available - * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. - * - * DStreams internally is characterized by a few basic properties: - * - A list of other DStreams that the DStream depends on - * - A time interval at which the DStream generates an RDD - * - A function that is used to generate an RDD after each time interval - */ -class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) - extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { - - override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) - - /** Return a new DStream containing only the elements that satisfy a predicate. */ - def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = - dstream.filter((x => f(x).booleanValue())) - - /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def cache(): JavaDStream[T] = dstream.cache() - - /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): JavaDStream[T] = dstream.persist() - - /** Persist the RDDs of this DStream with the given storage level */ - def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel) - - /** Generate an RDD for the given duration */ - def compute(validTime: Time): JavaRDD[T] = { - dstream.compute(validTime) match { - case Some(rdd) => new JavaRDD(rdd) - case None => null - } - } - - /** - * Return a new DStream in which each RDD contains all the elements in seen in a - * sliding window of time over this DStream. The new DStream generates RDDs with - * the same interval as this DStream. - * @param windowDuration width of the window; must be a multiple of this DStream's interval. 
- */ - def window(windowDuration: Duration): JavaDStream[T] = - dstream.window(windowDuration) - - /** - * Return a new DStream in which each RDD contains all the elements in seen in a - * sliding window of time over this DStream. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] = - dstream.window(windowDuration, slideDuration) - - /** - * Return a new DStream by unifying data of another DStream with this DStream. - * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. - */ - def union(that: JavaDStream[T]): JavaDStream[T] = - dstream.union(that.dstream) -} - -object JavaDStream { - implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] = - new JavaDStream[T](dstream) -} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala deleted file mode 100644 index 3ab5c1fdde..0000000000 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ /dev/null @@ -1,316 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.api.java - -import java.util.{List => JList} -import java.lang.{Long => JLong} - -import scala.collection.JavaConversions._ - -import spark.streaming._ -import spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD} -import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} -import java.util -import spark.RDD -import JavaDStream._ - -trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] - extends Serializable { - implicit val classManifest: ClassManifest[T] - - def dstream: DStream[T] - - def wrapRDD(in: RDD[T]): R - - implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { - in.map(new JLong(_)) - } - - /** - * Print the first ten elements of each RDD generated in this DStream. This is an output - * operator, so this DStream will be registered as an output stream and there materialized. - */ - def print() = dstream.print() - - /** - * Return a new DStream in which each RDD has a single element generated by counting each RDD - * of this DStream. - */ - def count(): JavaDStream[JLong] = dstream.count() - - /** - * Return a new DStream in which each RDD contains the counts of each distinct value in - * each RDD of this DStream. Hash partitioning is used to generate the RDDs with - * Spark's default number of partitions. 
- */ - def countByValue(): JavaPairDStream[T, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByValue()) - } - - /** - * Return a new DStream in which each RDD contains the counts of each distinct value in - * each RDD of this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` - * partitions. - * @param numPartitions number of partitions of each RDD in the new DStream. - */ - def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) - } - - - /** - * Return a new DStream in which each RDD has a single element generated by counting the number - * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).count() - */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = { - dstream.countByWindow(windowDuration, slideDuration) - } - - /** - * Return a new DStream in which each RDD contains the count of distinct elements in - * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with - * Spark's default number of partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[T, JLong] = { - JavaPairDStream.scalaToJavaLong( - dstream.countByValueAndWindow(windowDuration, slideDuration)) - } - - /** - * Return a new DStream in which each RDD contains the count of distinct elements in - * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` - * partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions number of partitions of each RDD in the new DStream. - */ - def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[T, JLong] = { - JavaPairDStream.scalaToJavaLong( - dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) - } - - /** - * Return a new DStream in which each RDD is generated by applying glom() to each RDD of - * this DStream. Applying glom() to an RDD coalesces all elements within each partition into - * an array. - */ - def glom(): JavaDStream[JList[T]] = - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - - /** Return the StreamingContext associated with this DStream */ - def context(): StreamingContext = dstream.context() - - /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[R](f: JFunction[T, R]): JavaDStream[R] = { - new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) - } - - /** Return a new DStream by applying a function to all elements of this DStream. 
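The Java wrappers above delegate to the corresponding DStream methods; on the Scala side, the counting operations combine naturally with map and flatMap. The socket address is a placeholder and `ssc` is the hypothetical context from earlier.

    import spark.streaming.StreamingContext._   // pair-DStream operations

    val lines  = ssc.socketTextStream("localhost", 9999)
    val words  = lines.flatMap(_.split(" "))
    val counts = words.countByValue()            // DStream[(String, Long)]
    counts.print()                               // output operator: prints first ten elements per batch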
*/ - def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] - new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) - } - - /** - * Return a new DStream by applying a function to all elements of this DStream, - * and then flattening the results - */ - def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - new JavaDStream(dstream.flatMap(fn)(f.elementType()))(f.elementType()) - } - - /** - * Return a new DStream by applying a function to all elements of this DStream, - * and then flattening the results - */ - def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] - new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType()) - } - - /** - * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs - * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition - * of the RDD. - */ - def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) - } - - /** - * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs - * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition - * of the RDD. - */ - def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) - : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType()) - } - - /** - * Return a new DStream in which each RDD has a single element generated by reducing each RDD - * of this DStream. - */ - def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) - - /** - * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByWindow( - reduceFunc: (T, T) => T, - windowDuration: Duration, - slideDuration: Duration - ): DStream[T] = { - dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) - } - - - /** - * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a sliding window over this DStream. However, the reduction is done incrementally - * using the old window's reduced value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient than reduceByWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". 
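An illustrative use of the incremental reduceByWindow described above: a running element count over the last 30 seconds, recomputed every 10 seconds, where subtraction "inverse reduces" the values that slide out of the window. It reuses the hypothetical `events` stream and the `Seconds` factory.

    // Count of events in the last 30 s, updated every 10 s, using the invertible (+, -) pair.
    val eventCounts = events
      .map(_ => 1L)
      .reduceByWindow(_ + _, _ - _, Seconds(30), Seconds(10))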
- * @param reduceFunc associative reduce function - * @param invReduceFunc inverse reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByWindow( - reduceFunc: JFunction2[T, T, T], - invReduceFunc: JFunction2[T, T, T], - windowDuration: Duration, - slideDuration: Duration - ): JavaDStream[T] = { - dstream.reduceByWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) - } - - /** - * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) - */ - def slice(fromTime: Time, toTime: Time): JList[R] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * this DStream will be registered as an output stream and therefore materialized. - */ - def foreach(foreachFunc: JFunction[R, Void]) { - dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd))) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * this DStream will be registered as an output stream and therefore materialized. - */ - def foreach(foreachFunc: JFunction2[R, Time, Void]) { - dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) - } - - /** - * Return a new DStream in which each RDD is generated by applying a function - * on each RDD of this DStream. - */ - def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - def scalaTransform (in: RDD[T]): RDD[U] = - transformFunc.call(wrapRDD(in)).rdd - dstream.transform(scalaTransform(_)) - } - - /** - * Return a new DStream in which each RDD is generated by applying a function - * on each RDD of this DStream. - */ - def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - def scalaTransform (in: RDD[T], time: Time): RDD[U] = - transformFunc.call(wrapRDD(in), time).rdd - dstream.transform(scalaTransform(_, _)) - } - - /** - * Return a new DStream in which each RDD is generated by applying a function - * on each RDD of this DStream. - */ - def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]): - JavaPairDStream[K2, V2] = { - implicit val cmk: ClassManifest[K2] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]] - implicit val cmv: ClassManifest[V2] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]] - def scalaTransform (in: RDD[T]): RDD[(K2, V2)] = - transformFunc.call(wrapRDD(in)).rdd - dstream.transform(scalaTransform(_)) - } - - /** - * Return a new DStream in which each RDD is generated by applying a function - * on each RDD of this DStream. 
- */ - def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]): - JavaPairDStream[K2, V2] = { - implicit val cmk: ClassManifest[K2] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]] - implicit val cmv: ClassManifest[V2] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]] - def scalaTransform (in: RDD[T], time: Time): RDD[(K2, V2)] = - transformFunc.call(wrapRDD(in), time).rdd - dstream.transform(scalaTransform(_, _)) - } - - /** - * Enable periodic checkpointing of RDDs of this DStream - * @param interval Time interval after which generated RDD will be checkpointed - */ - def checkpoint(interval: Duration) = { - dstream.checkpoint(interval) - } -} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala deleted file mode 100644 index ea08fb3826..0000000000 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ /dev/null @@ -1,613 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.api.java - -import java.util.{List => JList} -import java.lang.{Long => JLong} - -import scala.collection.JavaConversions._ - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import spark.{RDD, Partitioner} -import org.apache.hadoop.mapred.{JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.conf.Configuration -import spark.api.java.{JavaUtils, JavaRDD, JavaPairRDD} -import spark.storage.StorageLevel -import com.google.common.base.Optional -import spark.RDD - -class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( - implicit val kManifiest: ClassManifest[K], - implicit val vManifest: ClassManifest[V]) - extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { - - override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) - - // ======================================================================= - // Methods common to all DStream's - // ======================================================================= - - /** Return a new DStream containing only the elements that satisfy a predicate. 
*/ - def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = - dstream.filter((x => f(x).booleanValue())) - - /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def cache(): JavaPairDStream[K, V] = dstream.cache() - - /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): JavaPairDStream[K, V] = dstream.persist() - - /** Persist the RDDs of this DStream with the given storage level */ - def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) - - /** Method that generates a RDD for the given Duration */ - def compute(validTime: Time): JavaPairRDD[K, V] = { - dstream.compute(validTime) match { - case Some(rdd) => new JavaPairRDD(rdd) - case None => null - } - } - - /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * The new DStream generates RDDs with the same interval as this DStream. - * @param windowDuration width of the window; must be a multiple of this DStream's interval. - * @return - */ - def window(windowDuration: Duration): JavaPairDStream[K, V] = - dstream.window(windowDuration) - - /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowDuration duration (i.e., width) of the window; - * must be a multiple of this DStream's interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's interval - */ - def window(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = - dstream.window(windowDuration, slideDuration) - - /** - * Return a new DStream by unifying data of another DStream with this DStream. - * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. - */ - def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = - dstream.union(that.dstream) - - // ======================================================================= - // Methods only for PairDStream's - // ======================================================================= - - /** - * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to - * generate the RDDs with Spark's default number of partitions. - */ - def groupByKey(): JavaPairDStream[K, JList[V]] = - dstream.groupByKey().mapValues(seqAsJavaList _) - - /** - * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to - * generate the RDDs with `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = - dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) - - /** - * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] - * is used to control the partitioning of each RDD. - */ - def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = - dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) - - /** - * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the associative reduce function. Hash partitioning is used to generate the RDDs - * with Spark's default number of partitions. 
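For the per-key operations above, the Scala equivalents become available through the toPairDStreamFunctions implicit in StreamingContext; `words` is the hypothetical DStream[String] from the earlier sketch.

    import spark.streaming.StreamingContext._   // implicit conversion to PairDStreamFunctions

    val pairs      = words.map(word => (word, 1))
    val wordCounts = pairs.reduceByKey(_ + _)    // merge counts per key, default partitioning
    val grouped    = pairs.groupByKey(4)         // explicit number of partitions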
- */ - def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = - dstream.reduceByKey(func) - - /** - * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs - * with `numPartitions` partitions. - */ - def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = - dstream.reduceByKey(func, numPartitions) - - /** - * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the - * partitioning of each RDD. - */ - def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { - dstream.reduceByKey(func, partitioner) - } - - /** - * Combine elements of each key in DStream's RDDs using custom function. This is similar to the - * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more - * information. - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner - ): JavaPairDStream[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] - dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to - * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs - * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with - * Spark's default number of partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - */ - def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(seqAsJavaList _) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window. Similar to - * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to - * generate the RDDs with Spark's default number of partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(seqAsJavaList _) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * Similar to `DStream.groupByKey()`, but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. 
- */ - def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(seqAsJavaList _) - } - - /** - * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * Similar to `DStream.groupByKey()`, but applies it over a sliding window. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - */ - def groupByKeyAndWindow( - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner - ):JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(seqAsJavaList _) - } - - /** - * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream - * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate - * the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - */ - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration) - :JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to - * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to - * generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration - ):JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to - * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to - * generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. 
- */ - def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int - ): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, numPartitions) - } - - /** - * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to - * `DStream.reduceByKey()`, but applies it over a sliding window. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - */ - def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner - ): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner) - } - - /** - * Return a new DStream by reducing over a using incremental computation. - * The reduced value of over a new window is calculated using the old window's reduce value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - invReduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration - ): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) - } - - /** - * Return a new DStream by applying incremental `reduceByKey` over a sliding window. - * The reduced value of over a new window is calculated using the old window's reduce value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions number of partitions of each RDD in the new DStream. 
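A sketch of the windowed, incrementally maintained per-key count that the overloads above expose, reusing the hypothetical `pairs` stream and the `Seconds`/`Minutes` duration factories; in practice this also needs a checkpoint directory set via ssc.checkpoint(...).

    // Per-word counts over the last 10 minutes, updated every 2 seconds.
    // The (+, -) pair lets old batches be subtracted instead of recomputing the whole window.
    val windowedCounts = pairs.reduceByKeyAndWindow(
      (a: Int, b: Int) => a + b,   // add values entering the window
      (a: Int, b: Int) => a - b,   // "inverse reduce" values leaving the window
      Minutes(10),
      Seconds(2)
    )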
- * @param filterFunc function to filter expired key-value pairs; - * only pairs that satisfy the function are retained - * set this to null if you do not want to filter - */ - def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - invReduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int, - filterFunc: JFunction[(K, V), java.lang.Boolean] - ): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow( - reduceFunc, - invReduceFunc, - windowDuration, - slideDuration, - numPartitions, - (p: (K, V)) => filterFunc(p).booleanValue() - ) - } - - /** - * Return a new DStream by applying incremental `reduceByKey` over a sliding window. - * The reduced value of over a new window is calculated using the old window's reduce value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - * @param filterFunc function to filter expired key-value pairs; - * only pairs that satisfy the function are retained - * set this to null if you do not want to filter - */ - def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - invReduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner, - filterFunc: JFunction[(K, V), java.lang.Boolean] - ): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow( - reduceFunc, - invReduceFunc, - windowDuration, - slideDuration, - partitioner, - (p: (K, V)) => filterFunc(p).booleanValue() - ) - } - - private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): - (Seq[V], Option[S]) => Option[S] = { - val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values - val scalaState: Optional[S] = JavaUtils.optionToOptional(state) - val result: Optional[S] = in.apply(list, scalaState) - result.isPresent match { - case true => Some(result.get()) - case _ => None - } - } - scalaFunc - } - - /** - * Create a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of each key. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. 
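A minimal running-count example for updateStateByKey as documented above: the update function folds each batch's new values into the previous state, and returning None would drop the key. `pairs` is the hypothetical (word, 1) stream from before.

    // Running count per word across all batches; the state is an Int.
    val updateCount = (newValues: Seq[Int], state: Option[Int]) =>
      Some(state.getOrElse(0) + newValues.sum)

    val runningCounts = pairs.updateStateByKey[Int](updateCount)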
- * @tparam S State type - */ - def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]]) - : JavaPairDStream[K, S] = { - implicit val cm: ClassManifest[S] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]] - dstream.updateStateByKey(convertUpdateStateFunction(updateFunc)) - } - - /** - * Create a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of each key. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. - * @param numPartitions Number of partitions of each RDD in the new DStream. - * @tparam S State type - */ - def updateStateByKey[S: ClassManifest]( - updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], - numPartitions: Int) - : JavaPairDStream[K, S] = { - dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions) - } - - /** - * Create a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of the key. - * [[spark.Partitioner]] is used to control the partitioning of each RDD. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - * @tparam S State type - */ - def updateStateByKey[S: ClassManifest]( - updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], - partitioner: Partitioner - ): JavaPairDStream[K, S] = { - dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner) - } - - def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - dstream.mapValues(f) - } - - def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { - import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - dstream.flatMapValues(fn) - } - - /** - * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` - * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that - * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number - * of partitions. - */ - def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { - implicit val cm: ClassManifest[W] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] - dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) - } - - /** - * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` - * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that - * key in both RDDs. Partitioner is used to partition each generated RDD. 
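A rough sketch of `updateStateByKey` as described above, using the Scala API and the `pairs` stream from the earlier sketch. The update function folds each batch's new values into the previous per-key state; returning None would drop that key's state. The state type and update logic are illustrative assumptions:

    // Keeps a running count per word across batches (requires checkpointing to be enabled).
    val runningCounts = pairs.updateStateByKey[Int](
      (newValues: Seq[Int], state: Option[Int]) => Some(state.getOrElse(0) + newValues.sum)
    )
    runningCounts.print()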
- */ - def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner) - : JavaPairDStream[K, (JList[V], JList[W])] = { - implicit val cm: ClassManifest[W] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] - dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) - } - - /** - * Join `this` DStream with `other` DStream. HashPartitioner is used - * to partition each generated RDD into default number of partitions. - */ - def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = { - implicit val cm: ClassManifest[W] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] - dstream.join(other.dstream) - } - - /** - * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will - * be generated by joining RDDs from `this` and other DStream. Uses the given - * Partitioner to partition each generated RDD. - */ - def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner) - : JavaPairDStream[K, (V, W)] = { - implicit val cm: ClassManifest[W] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] - dstream.join(other.dstream, partitioner) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { - dstream.saveAsHadoopFiles(prefix, suffix) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]]) { - dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf) { - dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { - dstream.saveAsNewAPIHadoopFiles(prefix, suffix) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". - */ - def saveAsNewAPIHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) - } - - /** - * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is - * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". 
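A rough sketch of joining two keyed streams as documented above, via the Scala API; the second socket source and its port are illustrative assumptions, and `pairs` comes from the earlier sketch:

    val other  = ssc.socketTextStream("localhost", 9998)             // placeholder second source
                    .flatMap(_.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
    val totals = pairs.reduceByKey(_ + _)
    val joined = totals.join(other)       // DStream of (word, (countFromPairs, countFromOther))
    joined.print()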
- */ - def saveAsNewAPIHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = new Configuration) { - dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) - } - - override val classManifest: ClassManifest[(K, V)] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] -} - -object JavaPairDStream { - implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)]) - :JavaPairDStream[K, V] = - new JavaPairDStream[K, V](dstream) - - def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = { - implicit val cmk: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val cmv: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] - new JavaPairDStream[K, V](dstream.dstream) - } - - def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long]) - : JavaPairDStream[K, JLong] = { - StreamingContext.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala deleted file mode 100644 index b7720ad0ea..0000000000 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ /dev/null @@ -1,613 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.api.java - -import spark.streaming._ -import receivers.{ActorReceiver, ReceiverSupervisorStrategy} -import spark.streaming.dstream._ -import spark.storage.StorageLevel -import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import spark.api.java.{JavaSparkContext, JavaRDD} -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import twitter4j.Status -import akka.actor.Props -import akka.actor.SupervisorStrategy -import akka.zeromq.Subscribe -import scala.collection.JavaConversions._ -import java.lang.{Long => JLong, Integer => JInt} -import java.io.InputStream -import java.util.{Map => JMap} -import twitter4j.auth.Authorization - -/** - * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic - * information (such as, cluster URL and job name) to internally create a SparkContext, it provides - * methods used to create DStream from various input sources. - */ -class JavaStreamingContext(val ssc: StreamingContext) { - - // TODOs: - // - Test to/from Hadoop functions - // - Support creating and registering InputStreams - - - /** - * Creates a StreamingContext. 
- * @param master Name of the Spark Master - * @param appName Name to be used when registering with the scheduler - * @param batchDuration The time interval at which streaming data will be divided into batches - */ - def this(master: String, appName: String, batchDuration: Duration) = - this(new StreamingContext(master, appName, batchDuration, null, Nil, Map())) - - /** - * Creates a StreamingContext. - * @param master Name of the Spark Master - * @param appName Name to be used when registering with the scheduler - * @param batchDuration The time interval at which streaming data will be divided into batches - * @param sparkHome The SPARK_HOME directory on the slave nodes - * @param jarFile JAR file containing job code, to ship to cluster. This can be a path on the local - * file system or an HDFS, HTTP, HTTPS, or FTP URL. - */ - def this( - master: String, - appName: String, - batchDuration: Duration, - sparkHome: String, - jarFile: String) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, Seq(jarFile), Map())) - - /** - * Creates a StreamingContext. - * @param master Name of the Spark Master - * @param appName Name to be used when registering with the scheduler - * @param batchDuration The time interval at which streaming data will be divided into batches - * @param sparkHome The SPARK_HOME directory on the slave nodes - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this( - master: String, - appName: String, - batchDuration: Duration, - sparkHome: String, - jars: Array[String]) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, Map())) - - /** - * Creates a StreamingContext. - * @param master Name of the Spark Master - * @param appName Name to be used when registering with the scheduler - * @param batchDuration The time interval at which streaming data will be divided into batches - * @param sparkHome The SPARK_HOME directory on the slave nodes - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - * @param environment Environment variables to set on worker nodes - */ - def this( - master: String, - appName: String, - batchDuration: Duration, - sparkHome: String, - jars: Array[String], - environment: JMap[String, String]) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) - - /** - * Creates a StreamingContext using an existing SparkContext. - * @param sparkContext The underlying JavaSparkContext to use - * @param batchDuration The time interval at which streaming data will be divided into batches - */ - def this(sparkContext: JavaSparkContext, batchDuration: Duration) = - this(new StreamingContext(sparkContext.sc, batchDuration)) - - /** - * Re-creates a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. - */ - def this(path: String) = this (new StreamingContext(path)) - - /** The underlying SparkContext */ - val sc: JavaSparkContext = new JavaSparkContext(ssc.sc) - - /** - * Create an input stream that pulls messages form a Kafka Broker. - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. 
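For comparison, a rough sketch of constructing the underlying Scala StreamingContext directly, matching the first constructor above; the master URL, application name, batch interval, and checkpoint path are illustrative assumptions:

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // One-second batches on a local two-thread master (placeholder values).
    val ssc = new StreamingContext("local[2]", "StreamingExample", Seconds(1))
    // Alternatively, recover a previously checkpointed context from its checkpoint directory:
    // val recovered = new StreamingContext("hdfs:///tmp/streaming-checkpoint")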
- */ - def kafkaStream( - zkQuorum: String, - groupId: String, - topics: JMap[String, JInt]) - : JavaDStream[String] = { - implicit val cmt: ClassManifest[String] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), - StorageLevel.MEMORY_ONLY_SER_2) - } - - /** - * Create an input stream that pulls messages form a Kafka Broker. - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param storageLevel RDD storage level. Defaults to memory-only - * - */ - def kafkaStream( - zkQuorum: String, - groupId: String, - topics: JMap[String, JInt], - storageLevel: StorageLevel) - : JavaDStream[String] = { - implicit val cmt: ClassManifest[String] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), - storageLevel) - } - - /** - * Create an input stream that pulls messages form a Kafka Broker. - * @param typeClass Type of RDD - * @param decoderClass Type of kafka decoder - * @param kafkaParams Map of kafka configuration paramaters. - * See: http://kafka.apache.org/configuration.html - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param storageLevel RDD storage level. Defaults to memory-only - */ - def kafkaStream[T, D <: kafka.serializer.Decoder[_]]( - typeClass: Class[T], - decoderClass: Class[D], - kafkaParams: JMap[String, String], - topics: JMap[String, JInt], - storageLevel: StorageLevel) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] - ssc.kafkaStream[T, D]( - kafkaParams.toMap, - Map(topics.mapValues(_.intValue()).toSeq: _*), - storageLevel) - } - - /** - * Create a input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited - * lines. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @param storageLevel Storage level to use for storing the received objects - * (default: StorageLevel.MEMORY_AND_DISK_SER_2) - */ - def socketTextStream(hostname: String, port: Int, storageLevel: StorageLevel) - : JavaDStream[String] = { - ssc.socketTextStream(hostname, port, storageLevel) - } - - /** - * Create a input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited - * lines. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - */ - def socketTextStream(hostname: String, port: Int): JavaDStream[String] = { - ssc.socketTextStream(hostname, port) - } - - /** - * Create a input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given - * converter. 
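A rough sketch of the Kafka input documented above, from the Scala API; the ZooKeeper quorum, group id, topic map, and storage level are illustrative placeholders mirroring the wrapper call shown here:

    import org.apache.spark.storage.StorageLevel

    val kafkaLines = ssc.kafkaStream(
      "zkhost1:2181,zkhost2:2181",       // ZooKeeper quorum (placeholder)
      "example-consumer-group",          // consumer group id (placeholder)
      Map("events" -> 1),                // topic -> number of consumer threads (placeholder)
      StorageLevel.MEMORY_ONLY_SER_2)
    kafkaLines.count().print()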
- * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @param converter Function to convert the byte stream to objects - * @param storageLevel Storage level to use for storing the received objects - * @tparam T Type of the objects received (after converting bytes to objects) - */ - def socketStream[T]( - hostname: String, - port: Int, - converter: JFunction[InputStream, java.lang.Iterable[T]], - storageLevel: StorageLevel) - : JavaDStream[T] = { - def fn = (x: InputStream) => converter.apply(x).toIterator - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.socketStream(hostname, port, fn, storageLevel) - } - - /** - * Creates a input stream that monitors a Hadoop-compatible filesystem - * for new files and reads them as text files (using key as LongWritable, value - * as Text and input format as TextInputFormat). File names starting with . are ignored. - * @param directory HDFS directory to monitor for new file - */ - def textFileStream(directory: String): JavaDStream[String] = { - ssc.textFileStream(directory) - } - - /** - * Create a input stream from network source hostname:port, where data is received - * as serialized blocks (serialized using the Spark's serializer) that can be directly - * pushed into the block manager without deserializing them. This is the most efficient - * way to receive data. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @param storageLevel Storage level to use for storing the received objects - * @tparam T Type of the objects in the received blocks - */ - def rawSocketStream[T]( - hostname: String, - port: Int, - storageLevel: StorageLevel): JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port, storageLevel)) - } - - /** - * Create a input stream from network source hostname:port, where data is received - * as serialized blocks (serialized using the Spark's serializer) that can be directly - * pushed into the block manager without deserializing them. This is the most efficient - * way to receive data. - * @param hostname Hostname to connect to for receiving data - * @param port Port to connect to for receiving data - * @tparam T Type of the objects in the received blocks - */ - def rawSocketStream[T](hostname: String, port: Int): JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port)) - } - - /** - * Creates a input stream that monitors a Hadoop-compatible filesystem - * for new files and reads them using the given key-value types and input format. - * File names starting with . are ignored. 
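A rough sketch of the file-monitoring text input described above; the directory is a placeholder, and only files that appear in it after the stream starts are read:

    val fileLines = ssc.textFileStream("hdfs:///data/incoming")      // placeholder directory
    fileLines.count().print()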
- * @param directory HDFS directory to monitor for new file - * @tparam K Key type for reading HDFS file - * @tparam V Value type for reading HDFS file - * @tparam F Input format for reading HDFS file - */ - def fileStream[K, V, F <: NewInputFormat[K, V]](directory: String): JavaPairDStream[K, V] = { - implicit val cmk: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val cmv: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] - implicit val cmf: ClassManifest[F] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[F]] - ssc.fileStream[K, V, F](directory); - } - - /** - * Creates a input stream from a Flume source. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - */ - def flumeStream(hostname: String, port: Int, storageLevel: StorageLevel): - JavaDStream[SparkFlumeEvent] = { - ssc.flumeStream(hostname, port, storageLevel) - } - - - /** - * Creates a input stream from a Flume source. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - */ - def flumeStream(hostname: String, port: Int): JavaDStream[SparkFlumeEvent] = { - ssc.flumeStream(hostname, port) - } - - /** - * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J Authorization object - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def twitterStream( - twitterAuth: Authorization, - filters: Array[String], - storageLevel: StorageLevel - ): JavaDStream[Status] = { - ssc.twitterStream(Some(twitterAuth), filters, storageLevel) - } - - /** - * Create a input stream that returns tweets received from Twitter using Twitter4J's default - * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, - * .consumerSecret, .accessToken and .accessTokenSecret to be set. - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def twitterStream( - filters: Array[String], - storageLevel: StorageLevel - ): JavaDStream[Status] = { - ssc.twitterStream(None, filters, storageLevel) - } - - /** - * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J Authorization - * @param filters Set of filter strings to get only those tweets that match them - */ - def twitterStream( - twitterAuth: Authorization, - filters: Array[String] - ): JavaDStream[Status] = { - ssc.twitterStream(Some(twitterAuth), filters) - } - - /** - * Create a input stream that returns tweets received from Twitter using Twitter4J's default - * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, - * .consumerSecret, .accessToken and .accessTokenSecret to be set. - * @param filters Set of filter strings to get only those tweets that match them - */ - def twitterStream( - filters: Array[String] - ): JavaDStream[Status] = { - ssc.twitterStream(None, filters) - } - - /** - * Create a input stream that returns tweets received from Twitter. 
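A rough sketch of the Flume input documented above: Spark listens on the given host and port for Avro events pushed by a Flume agent. The host and port are illustrative placeholders:

    val flumeEvents = ssc.flumeStream("localhost", 41414)            // placeholder host/port
    flumeEvents.count().print()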
- * @param twitterAuth Twitter4J Authorization - */ - def twitterStream( - twitterAuth: Authorization - ): JavaDStream[Status] = { - ssc.twitterStream(Some(twitterAuth)) - } - - /** - * Create a input stream that returns tweets received from Twitter using Twitter4J's default - * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, - * .consumerSecret, .accessToken and .accessTokenSecret to be set. - */ - def twitterStream(): JavaDStream[Status] = { - ssc.twitterStream() - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String, - storageLevel: StorageLevel, - supervisorStrategy: SupervisorStrategy - ): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.actorStream[T](props, name, storageLevel, supervisorStrategy) - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String, - storageLevel: StorageLevel - ): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.actorStream[T](props, name, storageLevel) - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String - ): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.actorStream[T](props, name) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence - * of byte thus it needs the converter(which might be deserializer of bytes) - * to translate from sequence of sequence of bytes, where sequence refer to a frame - * and sub sequence refer to its payload. 
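A rough sketch of the Twitter input using Twitter4J's default OAuth credentials, as described above; the filter term is an illustrative placeholder, and the four twitter4j.oauth.* system properties must already be set:

    val tweets   = ssc.twitterStream(None, Seq("spark"))             // placeholder filter term
    val hashTags = tweets.flatMap(status => status.getText.split(" ").filter(_.startsWith("#")))
    hashTags.print()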
- * @param storageLevel Storage level to use for storing the received objects - */ - def zeroMQStream[T]( - publisherUrl:String, - subscribe: Subscribe, - bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T], - storageLevel: StorageLevel, - supervisorStrategy: SupervisorStrategy - ): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.zeroMQStream[T](publisherUrl, subscribe, bytesToObjects, storageLevel, supervisorStrategy) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence - * of byte thus it needs the converter(which might be deserializer of bytes) - * to translate from sequence of sequence of bytes, where sequence refer to a frame - * and sub sequence refer to its payload. - * @param storageLevel RDD storage level. Defaults to memory-only. - */ - def zeroMQStream[T]( - publisherUrl:String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], - storageLevel: StorageLevel - ): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator - ssc.zeroMQStream[T](publisherUrl, subscribe, fn, storageLevel) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence - * of byte thus it needs the converter(which might be deserializer of bytes) - * to translate from sequence of sequence of bytes, where sequence refer to a frame - * and sub sequence refer to its payload. - */ - def zeroMQStream[T]( - publisherUrl:String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]] - ): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator - ssc.zeroMQStream[T](publisherUrl, subscribe, fn) - } - - /** - * Registers an output stream that will be computed every interval - */ - def registerOutputStream(outputStream: JavaDStreamLike[_, _, _]) { - ssc.registerOutputStream(outputStream.dstream) - } - - /** - * Creates a input stream from an queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue. - * - * NOTE: changes to the queue after the stream is created will not be recognized. - * @param queue Queue of RDDs - * @tparam T Type of objects in the RDD - */ - def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val sQueue = new scala.collection.mutable.Queue[spark.RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) - ssc.queueStream(sQueue) - } - - /** - * Creates a input stream from an queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue. - * - * NOTE: changes to the queue after the stream is created will not be recognized. 
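A rough sketch of `queueStream`, which is mainly useful in tests; the queue is populated before the stream is created, since (per the note above) later changes to the queue are not recognized. The RDD contents and partition count are illustrative:

    import scala.collection.mutable.Queue

    val rddQueue = Queue(ssc.sparkContext.makeRDD(1 to 100, 2))      // placeholder data
    val queued   = ssc.queueStream(rddQueue)
    queued.reduce(_ + _).print()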
- * @param queue Queue of RDDs - * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @tparam T Type of objects in the RDD - */ - def queueStream[T](queue: java.util.Queue[JavaRDD[T]], oneAtATime: Boolean): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val sQueue = new scala.collection.mutable.Queue[spark.RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) - ssc.queueStream(sQueue, oneAtATime) - } - - /** - * Creates a input stream from an queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue. - * - * NOTE: changes to the queue after the stream is created will not be recognized. - * @param queue Queue of RDDs - * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty - * @tparam T Type of objects in the RDD - */ - def queueStream[T]( - queue: java.util.Queue[JavaRDD[T]], - oneAtATime: Boolean, - defaultRDD: JavaRDD[T]): JavaDStream[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val sQueue = new scala.collection.mutable.Queue[spark.RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) - ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) - } - - /** - * Sets the context to periodically checkpoint the DStream operations for master - * fault-tolerance. The graph will be checkpointed every batch interval. - * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored - */ - def checkpoint(directory: String) { - ssc.checkpoint(directory) - } - - /** - * Sets each DStreams in this context to remember RDDs it generated in the last given duration. - * DStreams remember RDDs only for a limited duration of duration and releases them for garbage - * collection. This method allows the developer to specify how to long to remember the RDDs ( - * if the developer wishes to query old data outside the DStream computation). - * @param duration Minimum duration that each DStream should remember its RDDs - */ - def remember(duration: Duration) { - ssc.remember(duration) - } - - /** - * Starts the execution of the streams. - */ - def start() = ssc.start() - - /** - * Sstops the execution of the streams. - */ - def stop() = ssc.stop() - -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala deleted file mode 100644 index 99553d295d..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
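A rough lifecycle sketch tying the calls above together: tell the context how long to retain generated RDDs, then start and eventually stop the streams. The retention duration is an illustrative assumption; the checkpoint call was shown in an earlier sketch:

    ssc.remember(Seconds(60))     // keep roughly the last minute of RDDs for ad-hoc queries
    ssc.start()
    // ... later, e.g. from a shutdown hook or a control thread:
    ssc.stop()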
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.{RDD, Partitioner} -import spark.rdd.CoGroupedRDD -import spark.streaming.{Time, DStream, Duration} - -private[streaming] -class CoGroupedDStream[K : ClassManifest]( - parents: Seq[DStream[(K, _)]], - partitioner: Partitioner - ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideDuration).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideDuration: Duration = parents.head.slideDuration - - override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { - val part = partitioner - val rdds = parents.flatMap(_.getOrCompute(validTime)) - if (rdds.size > 0) { - val q = new CoGroupedRDD[K](rdds, part) - Some(q) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala deleted file mode 100644 index 095137092a..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.RDD -import spark.streaming.{Time, StreamingContext} - -/** - * An input stream that always returns the same RDD on each timestep. Useful for testing. - */ -class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc_) { - - override def start() {} - - override def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - Some(rdd) - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala deleted file mode 100644 index de0536125d..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.RDD -import spark.rdd.UnionRDD -import spark.streaming.{DStreamCheckpointData, StreamingContext, Time} - -import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - -import scala.collection.mutable.{HashSet, HashMap} -import java.io.{ObjectInputStream, IOException} - -private[streaming] -class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - @transient ssc_ : StreamingContext, - directory: String, - filter: Path => Boolean = FileInputDStream.defaultFilter, - newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc_) { - - protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData - - // Latest file mod time seen till any point of time - private val lastModTimeFiles = new HashSet[String]() - private var lastModTime = 0L - - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null - @transient private[streaming] var files = new HashMap[Time, Array[String]] - - override def start() { - if (newFilesOnly) { - lastModTime = graph.zeroTime.milliseconds - } else { - lastModTime = 0 - } - logDebug("LastModTime initialized to " + lastModTime + ", new files only = " + newFilesOnly) - } - - override def stop() { } - - /** - * Finds the files that were modified since the last time this method was called and makes - * a union RDD out of them. Note that this maintains the list of files that were processed - * in the latest modification time in the previous call to this method. This is because the - * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. And new files may have the same modification time as the - * latest modification time in the previous call to this method yet was not reported in - * the previous call. 
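A simplified restatement of the new-file detection rule described above, not the actual implementation: a file is picked up if its modification time is at least lastModTime, at most the batch's validTime, and it was not already reported when it carried exactly lastModTime. The helper name and parameters are illustrative:

    def isNewFile(
        modTime: Long,
        path: String,
        lastModTime: Long,
        lastModTimeFiles: Set[String],
        validTime: Long): Boolean = {
      if (modTime < lastModTime) false                                          // older than last round
      else if (modTime == lastModTime && lastModTimeFiles.contains(path)) false // already reported
      else if (modTime > validTime) false                                       // belongs to a later batch
      else true
    }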
- */ - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - assert(validTime.milliseconds >= lastModTime, "Trying to get new files for really old time [" + validTime + " < " + lastModTime) - - // Create the filter for selecting new files - val newFilter = new PathFilter() { - // Latest file mod time seen in this round of fetching files and its corresponding files - var latestModTime = 0L - val latestModTimeFiles = new HashSet[String]() - - def accept(path: Path): Boolean = { - if (!filter(path)) { // Reject file if it does not satisfy filter - logDebug("Rejected by filter " + path) - return false - } else { // Accept file only if - val modTime = fs.getFileStatus(path).getModificationTime() - logDebug("Mod time for " + path + " is " + modTime) - if (modTime < lastModTime) { - logDebug("Mod time less than last mod time") - return false // If the file was created before the last time it was called - } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { - logDebug("Mod time equal to last mod time, but file considered already") - return false // If the file was created exactly as lastModTime but not reported yet - } else if (modTime > validTime.milliseconds) { - logDebug("Mod time more than valid time") - return false // If the file was created after the time this function call requires - } - if (modTime > latestModTime) { - latestModTime = modTime - latestModTimeFiles.clear() - logDebug("Latest mod time updated to " + latestModTime) - } - latestModTimeFiles += path.toString - logDebug("Accepted " + path) - return true - } - } - } - logDebug("Finding new files at time " + validTime + " for last mod time = " + lastModTime) - val newFiles = fs.listStatus(path, newFilter).map(_.getPath.toString) - logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) - if (newFiles.length > 0) { - // Update the modification time and the files processed for that modification time - if (lastModTime != newFilter.latestModTime) { - lastModTime = newFilter.latestModTime - lastModTimeFiles.clear() - } - lastModTimeFiles ++= newFilter.latestModTimeFiles - logDebug("Last mod time updated to " + lastModTime) - } - files += ((validTime, newFiles)) - Some(filesToRDD(newFiles)) - } - - /** Clear the old time-to-files mappings along with old RDDs */ - protected[streaming] override def clearOldMetadata(time: Time) { - super.clearOldMetadata(time) - val oldFiles = files.filter(_._1 <= (time - rememberDuration)) - files --= oldFiles.keys - logInfo("Cleared " + oldFiles.size + " old files that were older than " + - (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) - logDebug("Cleared files are:\n" + - oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) - } - - /** Generate one RDD from an array of files */ - protected[streaming] def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - new UnionRDD( - context.sparkContext, - files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) - ) - } - - private def path: Path = { - if (path_ == null) path_ = new Path(directory) - path_ - } - - private def fs: FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) - fs_ - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - logDebug(this.getClass().getSimpleName + ".readObject used") - ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[(K,V)]] () - files = new HashMap[Time, Array[String]] - } - - /** - * A custom version of the DStreamCheckpointData that stores names of - * 
Hadoop files as checkpoint data. - */ - private[streaming] - class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - - def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] - - override def update() { - hadoopFiles.clear() - hadoopFiles ++= files - } - - override def cleanup() { } - - override def restore() { - hadoopFiles.foreach { - case (t, f) => { - // Restore the metadata in both files and generatedRDDs - logInfo("Restoring files for time " + t + " - " + - f.mkString("[", ", ", "]") ) - files += ((t, f)) - generatedRDDs += ((t, filesToRDD(f))) - } - } - } - - override def toString() = { - "[\n" + hadoopFiles.size + " file sets\n" + - hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" - } - } -} - -private[streaming] -object FileInputDStream { - def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") -} - - diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala deleted file mode 100644 index 9d8c5c3175..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD - -private[streaming] -class FilteredDStream[T: ClassManifest]( - parent: DStream[T], - filterFunc: T => Boolean - ) extends DStream[T](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[T]] = { - parent.getOrCompute(validTime).map(_.filter(filterFunc)) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala deleted file mode 100644 index 78d7117f0f..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD -import spark.SparkContext._ - -private[streaming] -class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - flatMapValueFunc: V => TraversableOnce[U] - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala deleted file mode 100644 index d13bebb10f..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD - -private[streaming] -class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - flatMapFunc: T => Traversable[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala deleted file mode 100644 index 4906f503c2..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.StreamingContext - -import spark.Utils -import spark.storage.StorageLevel - -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.avro.ipc.NettyServer - -import scala.collection.JavaConversions._ - -import java.net.InetSocketAddress -import java.io.{ObjectInput, ObjectOutput, Externalizable} -import java.nio.ByteBuffer - -private[streaming] -class FlumeInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel -) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { - - override def getReceiver(): NetworkReceiver[SparkFlumeEvent] = { - new FlumeReceiver(host, port, storageLevel) - } -} - -/** - * A wrapper class for AvroFlumeEvent's with a custom serialization format. - * - * This is necessary because AvroFlumeEvent uses inner data structures - * which are not serializable. - */ -class SparkFlumeEvent() extends Externalizable { - var event : AvroFlumeEvent = new AvroFlumeEvent() - - /* De-serialize from bytes. */ - def readExternal(in: ObjectInput) { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.read(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.read(keyBuff) - val key : String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.read(valBuff) - val value : String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - - event.setBody(ByteBuffer.wrap(bodyBuff)) - event.setHeaders(headers) - } - - /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput) { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) - - val numHeaders = event.getHeaders.size() - out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} - -private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { - val event = new SparkFlumeEvent - event.event = in - event - } -} - -/** A simple server that implements Flume's Avro protocol. 
*/ -private[streaming] -class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { - override def append(event : AvroFlumeEvent) : Status = { - receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event) - Status.OK - } - - override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event)) - Status.OK - } -} - -/** A NetworkReceiver which listens for events using the - * Flume Avro interface.*/ -private[streaming] -class FlumeReceiver( - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent] { - - lazy val blockGenerator = new BlockGenerator(storageLevel) - - protected override def onStart() { - val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)); - val server = new NettyServer(responder, new InetSocketAddress(host, port)); - blockGenerator.start() - server.start() - logInfo("Flume receiver started") - } - - protected override def onStop() { - blockGenerator.stop() - logInfo("Flume receiver stopped") - } - - override def getLocationPreference = Some(host) -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala deleted file mode 100644 index 7df537eb56..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.RDD -import spark.streaming.{Duration, DStream, Job, Time} - -private[streaming] -class ForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - foreachFunc(rdd, time) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala deleted file mode 100644 index 06fda6fe8e..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD - -private[streaming] -class GlommedDStream[T: ClassManifest](parent: DStream[T]) - extends DStream[Array[T]](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[Array[T]]] = { - parent.getOrCompute(validTime).map(_.glom()) - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala deleted file mode 100644 index 4dbdec459d..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Time, Duration, StreamingContext, DStream} - -/** - * This is the abstract base class for all input streams. This class provides to methods - * start() and stop() which called by the scheduler to start and stop receiving data/ - * Input streams that can generated RDDs from new data just by running a service on - * the driver node (that is, without running a receiver onworker nodes) can be - * implemented by directly subclassing this InputDStream. For example, - * FileInputDStream, a subclass of InputDStream, monitors a HDFS directory for - * new files and generates RDDs on the new files. For implementing input streams - * that requires running a receiver on the worker nodes, use NetworkInputDStream - * as the parent class. - */ -abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) - extends DStream[T](ssc_) { - - var lastValidTime: Time = null - - /** - * Checks whether the 'time' is valid wrt slideDuration for generating RDD. - * Additionally it also ensures valid times are in strictly increasing order. - * This ensures that InputDStream.compute() is called strictly on increasing - * times. 
- */ - override protected def isTimeValid(time: Time): Boolean = { - if (!super.isTimeValid(time)) { - false // Time not valid - } else { - // Time is valid, but check it it is more than lastValidTime - if (lastValidTime != null && time < lastValidTime) { - logWarning("isTimeValid called with " + time + " where as last valid time is " + lastValidTime) - } - lastValidTime = time - true - } - } - - override def dependencies = List() - - override def slideDuration: Duration = { - if (ssc == null) throw new Exception("ssc is null") - if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") - ssc.graph.batchDuration - } - - /** Method called to start receiving data. Subclasses must implement this method. */ - def start() - - /** Method called to stop receiving data. Subclasses must implement this method. */ - def stop() -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala deleted file mode 100644 index 6ee588af15..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.Logging -import spark.storage.StorageLevel -import spark.streaming.{Time, DStreamCheckpointData, StreamingContext} - -import java.util.Properties -import java.util.concurrent.Executors - -import kafka.consumer._ -import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.Decoder -import kafka.utils.{Utils, ZKGroupTopicDirs} -import kafka.utils.ZkUtils._ -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient._ - -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ - - -/** - * Input stream that pulls messages from a Kafka Broker. - * - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param storageLevel RDD storage level. 
- */ -private[streaming] -class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest]( - @transient ssc_ : StreamingContext, - kafkaParams: Map[String, String], - topics: Map[String, Int], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - - def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver[T, D](kafkaParams, topics, storageLevel) - .asInstanceOf[NetworkReceiver[T]] - } -} - -private[streaming] -class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest]( - kafkaParams: Map[String, String], - topics: Map[String, Int], - storageLevel: StorageLevel - ) extends NetworkReceiver[Any] { - - // Handles pushing data into the BlockManager - lazy protected val blockGenerator = new BlockGenerator(storageLevel) - // Connection to Kafka - var consumerConnector : ConsumerConnector = null - - def onStop() { - blockGenerator.stop() - } - - def onStart() { - - blockGenerator.start() - - // In case we are using multiple Threads to handle Kafka Messages - val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - - logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) - - // Kafka connection properties - val props = new Properties() - kafkaParams.foreach(param => props.put(param._1, param._2)) - - // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) - val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig) - logInfo("Connected to " + kafkaParams("zk.connect")) - - // When autooffset.reset is defined, it is our responsibility to try and whack the - // consumer group zk node. - if (kafkaParams.contains("autooffset.reset")) { - tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) - } - - // Create Threads for each Topic/Message Stream we are listening - val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]] - val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder) - - // Start the messages handler for each partition - topicMessageStreams.values.foreach { streams => - streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } - } - } - - // Handles Kafka Messages - private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable { - def run() { - logInfo("Starting MessageHandler.") - for (msgAndMetadata <- stream) { - blockGenerator += msgAndMetadata.message - } - } - } - - // It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because - // Kafka 0.7.2 only honors this param when the group is not in zookeeper. - // - // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' - // ConsoleConsumer. 
See code related to 'autooffset.reset' when it is set to 'smallest'/'largest': - // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala - private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { - try { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) - zk.deleteRecursive(dir) - zk.close() - } catch { - case _ => // swallow - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala deleted file mode 100644 index af41a1b9ac..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD - -private[streaming] -class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U], - preservePartitioning: Boolean - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala deleted file mode 100644 index 8d8a6161c6..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD -import spark.SparkContext._ - -private[streaming] -class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - mapValueFunc: V => U - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala deleted file mode 100644 index 3fda84a38a..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD - -private[streaming] -class MappedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - mapFunc: T => U - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.map[U](mapFunc)) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala deleted file mode 100644 index 1db0a69a2f..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver} - -import spark.{Logging, SparkEnv, RDD} -import spark.rdd.BlockRDD -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer - -import java.nio.ByteBuffer - -import akka.actor.{Props, Actor} -import akka.pattern.ask -import akka.dispatch.Await -import akka.util.duration._ -import spark.streaming.util.{RecurringTimer, SystemClock} -import java.util.concurrent.ArrayBlockingQueue - -/** - * Abstract class for defining any InputDStream that has to start a receiver on worker - * nodes to receive external data. Specific implementations of NetworkInputDStream must - * define the getReceiver() function that gets the receiver object of type - * [[spark.streaming.dstream.NetworkReceiver]] that will be sent to the workers to receive - * data. - * @param ssc_ Streaming context that will execute this input stream - * @tparam T Class type of the object of this stream - */ -abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) - extends InputDStream[T](ssc_) { - - // This is an unique identifier that is used to match the network receiver with the - // corresponding network input stream. - val id = ssc.getNewNetworkStreamId() - - /** - * Gets the receiver object that will be sent to the worker nodes - * to receive data. This method needs to defined by any specific implementation - * of a NetworkInputDStream. - */ - def getReceiver(): NetworkReceiver[T] - - // Nothing to start or stop as both taken care of by the NetworkInputTracker. - def start() {} - - def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - // If this is called for any time before the start time of the context, - // then this returns an empty RDD. This may happen when recovering from a - // master failure - if (validTime >= graph.startTime) { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - Some(new BlockRDD[T](ssc.sc, blockIds)) - } else { - Some(new BlockRDD[T](ssc.sc, Array[String]())) - } - } -} - - -private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage -private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage - -/** - * Abstract class of a receiver that can be run on worker nodes to receive external data. See - * [[spark.streaming.dstream.NetworkInputDStream]] for an explanation. - */ -abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Logging { - - initLogging() - - lazy protected val env = SparkEnv.get - - lazy protected val actor = env.actorSystem.actorOf( - Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) - - lazy protected val receivingThread = Thread.currentThread() - - protected var streamId: Int = -1 - - /** - * This method will be called to start receiving data. All your receiver - * starting code should be implemented by defining this function. - */ - protected def onStart() - - /** This method will be called to stop receiving data. */ - protected def onStop() - - /** Conveys a placement preference (hostname) for this receiver. */ - def getLocationPreference() : Option[String] = None - - /** - * Starts the receiver. First is accesses all the lazy members to - * materialize them. 
Then it calls the user-defined onStart() method to start - * other threads, etc required to receiver the data. - */ - def start() { - try { - // Access the lazy vals to materialize them - env - actor - receivingThread - - // Call user-defined onStart() - onStart() - } catch { - case ie: InterruptedException => - logInfo("Receiving thread interrupted") - //println("Receiving thread interrupted") - case e: Exception => - stopOnError(e) - } - } - - /** - * Stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. - */ - def stop() { - receivingThread.interrupt() - onStop() - //TODO: terminate the actor - } - - /** - * Stops the receiver and reports exception to the tracker. - * This should be called whenever an exception is to be handled on any thread - * of the receiver. - */ - protected def stopOnError(e: Exception) { - logError("Error receiving data", e) - stop() - actor ! ReportError(e.toString) - } - - - /** - * Pushes a block (as an ArrayBuffer filled with data) into the block manager. - */ - def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { - env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId, metadata) - } - - /** - * Pushes a block (as bytes) into the block manager. - */ - def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { - env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) - } - - /** A helper actor that communicates with the NetworkInputTracker */ - private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.driver.host", "localhost") - val port = System.getProperty("spark.driver.port", "7077").toInt - val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorFor(url) - val timeout = 5.seconds - - override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => - stop() - tracker ! DeregisterReceiver(streamId, msg) - } - } - - protected[streaming] def setStreamId(id: Int) { - streamId = id - } - - /** - * Batches objects created by a [[spark.streaming.dstream.NetworkReceiver]] and puts them into - * appropriately named blocks at regular intervals. This class starts two threads, - * one to periodically start a new batch and prepare the previous batch of as a block, - * the other to push the blocks into the block manager. 
- */ - class BlockGenerator(storageLevel: StorageLevel) - extends Serializable with Logging { - - case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null) - - val clock = new SystemClock() - val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - private def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - NetworkReceiver.this.stop() - } - } - - private def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - NetworkReceiver.this.stop() - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala deleted file mode 100644 index 33f7cd063f..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.streaming.StreamingContext - -private[streaming] -class PluggableInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - receiver: NetworkReceiver[T]) extends NetworkInputDStream[T](ssc_) { - - def getReceiver(): NetworkReceiver[T] = { - receiver - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala deleted file mode 100644 index b269061b73..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.RDD -import spark.rdd.UnionRDD - -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer -import spark.streaming.{Time, StreamingContext} - -private[streaming] -class QueueInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val queue: Queue[RDD[T]], - oneAtATime: Boolean, - defaultRDD: RDD[T] - ) extends InputDStream[T](ssc) { - - override def start() { } - - override def stop() { } - - override def compute(validTime: Time): Option[RDD[T]] = { - val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue - } - if (buffer.size > 0) { - if (oneAtATime) { - Some(buffer.head) - } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) - } - } else if (defaultRDD != null) { - Some(defaultRDD) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala deleted file mode 100644 index 236f74f575..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.Logging -import spark.storage.StorageLevel -import spark.streaming.StreamingContext - -import java.net.InetSocketAddress -import java.nio.ByteBuffer -import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException -import java.util.concurrent.ArrayBlockingQueue - - -/** - * An input stream that reads blocks of serialized objects from a given network address. - * The blocks will be inserted directly into the block store. This is the fastest way to get - * data into Spark Streaming, though it requires the sender to batch data and serialize it - * in the format that the system is configured with. - */ -private[streaming] -class RawInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - def getReceiver(): NetworkReceiver[T] = { - new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] - } -} - -private[streaming] -class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) - extends NetworkReceiver[Any] { - - var blockPushingThread: Thread = null - - override def getLocationPreference = None - - def onStart() { - // Open a socket to the target address and keep reading from it - logInfo("Connecting to " + host + ":" + port) - val channel = SocketChannel.open() - channel.configureBlocking(true) - channel.connect(new InetSocketAddress(host, port)) - logInfo("Connected to " + host + ":" + port) - - val queue = new ArrayBlockingQueue[ByteBuffer](2) - - blockPushingThread = new Thread { - setDaemon(true) - override def run() { - var nextBlockNumber = 0 - while (true) { - val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber - nextBlockNumber += 1 - pushBlock(blockId, buffer, null, storageLevel) - } - } - } - blockPushingThread.start() - - val lengthBuffer = ByteBuffer.allocate(4) - while (true) { - lengthBuffer.clear() - readFully(channel, lengthBuffer) - lengthBuffer.flip() - val length = lengthBuffer.getInt() - val dataBuffer = ByteBuffer.allocate(length) - readFully(channel, dataBuffer) - dataBuffer.flip() - logInfo("Read a block with " + length + " bytes") - queue.put(dataBuffer) - } - } - - def onStop() { - if (blockPushingThread != null) blockPushingThread.interrupt() - } - - /** Read a buffer fully from a given Channel */ - private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { - while (dest.position < dest.limit) { - if (channel.read(dest) == -1) { - throw new EOFException("End of channel") - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala deleted file mode 100644 index 96260501ab..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.StreamingContext._ - -import spark.RDD -import spark.rdd.{CoGroupedRDD, MapPartitionsRDD} -import spark.Partitioner -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import spark.streaming.{Duration, Interval, Time, DStream} - -private[streaming] -class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - parent: DStream[(K, V)], - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - filterFunc: Option[((K, V)) => Boolean], - _windowDuration: Duration, - _slideDuration: Duration, - partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { - - assert(_windowDuration.isMultipleOf(parent.slideDuration), - "The window duration of ReducedWindowedDStream (" + _slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" - ) - - assert(_slideDuration.isMultipleOf(parent.slideDuration), - "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" - ) - - // Reduce each batch of data using reduceByKey which will be further reduced by window - // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) - - // Persist RDDs to memory by default as these RDDs are going to be reused. 
- super.persist(StorageLevel.MEMORY_ONLY_SER) - reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowDuration: Duration = _windowDuration - - override def dependencies = List(reducedStream) - - override def slideDuration: Duration = _slideDuration - - override val mustCheckpoint = true - - override def parentRememberDuration: Duration = rememberDuration + windowDuration - - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { - super.persist(storageLevel) - reducedStream.persist(storageLevel) - this - } - - override def checkpoint(interval: Duration): DStream[(K, V)] = { - super.checkpoint(interval) - //reducedStream.checkpoint(interval) - this - } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - val reduceF = reduceFunc - val invReduceF = invReduceFunc - - val currentTime = validTime - val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration, currentTime) - val previousWindow = currentWindow - slideDuration - - logDebug("Window time = " + windowDuration) - logDebug("Slide time = " + slideDuration) - logDebug("ZeroTime = " + zeroTime) - logDebug("Current window = " + currentWindow) - logDebug("Previous window = " + previousWindow) - - // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time - // |_____________________________| - // - // |________ _________| |________ _________| - // | | - // V V - // old RDDs new RDDs - // - - // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = - reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) - logDebug("# old RDDs = " + oldRDDs.size) - - // Get the RDDs of the reduced values in "new time steps" - val newRDDs = - reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) - logDebug("# new RDDs = " + newRDDs.size) - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = - getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) - - // Make the list of RDDs that needs to cogrouped together for reducing their reduced values - val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs - - // Cogroup the reduced RDDs and merge the reduced values - val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]], partitioner) - //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - - val numOldValues = oldRDDs.size - val numNewValues = newRDDs.size - - val mergeValues = (seqOfValues: Seq[Seq[V]]) => { - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - // Getting reduced values "new time steps" - val newValues = - (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - - if (seqOfValues(0).isEmpty) { - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found. 
" + - "Are you sure your key class hashes consistently?") - } - // Reduce the new values - newValues.reduce(reduceF) // return - } else { - // Get the previous window's reduced value - var tempValue = seqOfValues(0).head - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - tempValue = reduceF(tempValue, newValues.reduce(reduceF)) - } - tempValue // return - } - } - - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - - if (filterFunc.isDefined) { - Some(mergedValuesRDD.filter(filterFunc.get)) - } else { - Some(mergedValuesRDD) - } - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala deleted file mode 100644 index 83b57b27f7..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.{RDD, Partitioner} -import spark.SparkContext._ -import spark.streaming.{Duration, DStream, Time} - -private[streaming] -class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: DStream[(K,V)], - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner - ) extends DStream [(K,C)] (parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { - parent.getOrCompute(validTime) match { - case Some(rdd) => - Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) - case None => None - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala deleted file mode 100644 index 5877b10e0e..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.streaming.StreamingContext -import spark.storage.StorageLevel -import spark.util.NextIterator - -import java.io._ -import java.net.Socket - -private[streaming] -class SocketInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_) { - - def getReceiver(): NetworkReceiver[T] = { - new SocketReceiver(host, port, bytesToObjects, storageLevel) - } -} - -private[streaming] -class SocketReceiver[T: ClassManifest]( - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkReceiver[T] { - - lazy protected val blockGenerator = new BlockGenerator(storageLevel) - - override def getLocationPreference = None - - protected def onStart() { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - blockGenerator.start() - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - blockGenerator += obj - } - } - - protected def onStop() { - blockGenerator.stop() - } - -} - -private[streaming] -object SocketReceiver { - - /** - * This methods translates the data from an inputstream (say, from a socket) - * to '\n' delimited strings and returns an iterator to access the strings. - */ - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) - new NextIterator[String] { - protected override def getNext() = { - val nextValue = dataInputStream.readLine() - if (nextValue == null) { - finished = true - } - nextValue - } - - protected override def close() { - dataInputStream.close() - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala deleted file mode 100644 index 4b46613d5e..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.RDD -import spark.Partitioner -import spark.SparkContext._ -import spark.storage.StorageLevel -import spark.streaming.{Duration, Time, DStream} - -private[streaming] -class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest]( - parent: DStream[(K, V)], - updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], - partitioner: Partitioner, - preservePartitioning: Boolean - ) extends DStream[(K, S)](parent.ssc) { - - super.persist(StorageLevel.MEMORY_ONLY_SER) - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override val mustCheckpoint = true - - override def compute(validTime: Time): Option[RDD[(K, S)]] = { - - // Try to get the previous state RDD - getOrCompute(validTime - slideDuration) match { - - case Some(prevStateRDD) => { // If previous state RDD exists - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption) - }) - updateFuncLocal(i) - } - val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime) - return Some(stateRDD) - } - case None => { // If parent RDD does not exist - - // Re-apply the update function to the old state RDD - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, S)]) => { - val i = iterator.map(t => (t._1, Seq[V](), Option(t._2))) - updateFuncLocal(i) - } - val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) - return Some(stateRDD) - } - } - } - - case None => { // If previous session RDD does not exist (first input data) - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) - } - - val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime + " (first)") - return Some(sessionRDD) - } - case None => { // If parent RDD does not exist, then nothing to do! - //logDebug("Not generating state RDD (no previous state, no parent)") - return None - } - } - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala deleted file mode 100644 index e7fbc5bbcf..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark.RDD -import spark.streaming.{Duration, DStream, Time} - -private[streaming] -class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - transformFunc: (RDD[T], Time) => RDD[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala deleted file mode 100644 index f09a8b9f90..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.dstream - -import spark._ -import spark.streaming._ -import storage.StorageLevel -import twitter4j._ -import twitter4j.auth.Authorization -import java.util.prefs.Preferences -import twitter4j.conf.ConfigurationBuilder -import twitter4j.conf.PropertyConfiguration -import twitter4j.auth.OAuthAuthorization -import twitter4j.auth.AccessToken - -/* A stream of Twitter statuses, potentially filtered by one or more keywords. -* -* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. -* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is -* such that this may return a sampled subset of all tweets during each interval. -* -* If no Authorization object is provided, initializes OAuth authorization using the system -* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. 
-*/ -private[streaming] -class TwitterInputDStream( - @transient ssc_ : StreamingContext, - twitterAuth: Option[Authorization], - filters: Seq[String], - storageLevel: StorageLevel - ) extends NetworkInputDStream[Status](ssc_) { - - private def createOAuthAuthorization(): Authorization = { - new OAuthAuthorization(new ConfigurationBuilder().build()) - } - - private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) - - override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(authorization, filters, storageLevel) - } -} - -private[streaming] -class TwitterReceiver( - twitterAuth: Authorization, - filters: Seq[String], - storageLevel: StorageLevel - ) extends NetworkReceiver[Status] { - - var twitterStream: TwitterStream = _ - lazy val blockGenerator = new BlockGenerator(storageLevel) - - protected override def onStart() { - blockGenerator.start() - twitterStream = new TwitterStreamFactory().getInstance(twitterAuth) - twitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { - blockGenerator += status - } - // Unimplemented - def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} - def onTrackLimitationNotice(i: Int) {} - def onScrubGeo(l: Long, l1: Long) {} - def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) { stopOnError(e) } - }) - - val query: FilterQuery = new FilterQuery - if (filters.size > 0) { - query.track(filters.toArray) - twitterStream.filter(query) - } else { - twitterStream.sample() - } - logInfo("Twitter receiver started") - } - - protected override def onStop() { - blockGenerator.stop() - twitterStream.shutdown() - logInfo("Twitter receiver stopped") - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala deleted file mode 100644 index 3eaa9a7e7f..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.streaming.{Duration, DStream, Time} -import spark.RDD -import collection.mutable.ArrayBuffer -import spark.rdd.UnionRDD - -private[streaming] -class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) - extends DStream[T](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideDuration).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideDuration: Duration = parents.head.slideDuration - - override def compute(validTime: Time): Option[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { - case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) - } else { - None - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala deleted file mode 100644 index fd24d61730..0000000000 --- a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.dstream - -import spark.RDD -import spark.rdd.UnionRDD -import spark.storage.StorageLevel -import spark.streaming.{Duration, Interval, Time, DStream} - -private[streaming] -class WindowedDStream[T: ClassManifest]( - parent: DStream[T], - _windowDuration: Duration, - _slideDuration: Duration) - extends DStream[T](parent.ssc) { - - if (!_windowDuration.isMultipleOf(parent.slideDuration)) - throw new Exception("The window duration of WindowedDStream (" + _slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") - - if (!_slideDuration.isMultipleOf(parent.slideDuration)) - throw new Exception("The slide duration of WindowedDStream (" + _slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") - - parent.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowDuration: Duration = _windowDuration - - override def dependencies = List(parent) - - override def slideDuration: Duration = _slideDuration - - override def parentRememberDuration: Duration = rememberDuration + windowDuration - - override def compute(validTime: Time): Option[RDD[T]] = { - val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime) - Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala deleted file mode 100644 index abeeff11b9..0000000000 --- a/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.receivers - -import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy } -import akka.actor.{ actorRef2Scala, ActorRef } -import akka.actor.{ PossiblyHarmful, OneForOneStrategy } - -import spark.storage.StorageLevel -import spark.streaming.dstream.NetworkReceiver - -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.ArrayBuffer - -/** A helper with set of defaults for supervisor strategy **/ -object ReceiverSupervisorStrategy { - - import akka.util.duration._ - import akka.actor.SupervisorStrategy._ - - val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = - 15 millis) { - case _: RuntimeException ⇒ Restart - case _: Exception ⇒ Escalate - } -} - -/** - * A receiver trait to be mixed in with your Actor to gain access to - * pushBlock API. 
- * - * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html - * - * @example {{{ - * class MyActor extends Actor with Receiver{ - * def receive { - * case anything :String ⇒ pushBlock(anything) - * } - * } - * //Can be plugged in actorStream as follows - * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") - * - * }}} - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of push block and InputDStream - * should be same. - * - */ -trait Receiver { self: Actor ⇒ - def pushBlock[T: ClassManifest](iter: Iterator[T]) { - context.parent ! Data(iter) - } - - def pushBlock[T: ClassManifest](data: T) { - context.parent ! Data(data) - } - -} - -/** - * Statistics for querying the supervisor about state of workers - */ -case class Statistics(numberOfMsgs: Int, - numberOfWorkers: Int, - numberOfHiccups: Int, - otherInfo: String) - -/** Case class to receive data sent by child actors **/ -private[streaming] case class Data[T: ClassManifest](data: T) - -/** - * Provides Actors as receivers for receiving stream. - * - * As Actors can also be used to receive data from almost any stream source. - * A nice set of abstraction(s) for actors as receivers is already provided for - * a few general cases. It is thus exposed as an API where user may come with - * his own Actor to run as receiver for Spark Streaming input source. - * - * This starts a supervisor actor which starts workers and also provides - * [http://doc.akka.io/docs/akka/2.0.5/scala/fault-tolerance.html fault-tolerance]. - * - * Here's a way to start more supervisor/workers as its children. - * - * @example {{{ - * context.parent ! Props(new Supervisor) - * }}} OR {{{ - * context.parent ! Props(new Worker,"Worker") - * }}} - * - * - */ -private[streaming] class ActorReceiver[T: ClassManifest]( - props: Props, - name: String, - storageLevel: StorageLevel, - receiverSupervisorStrategy: SupervisorStrategy) - extends NetworkReceiver[T] { - - protected lazy val blocksGenerator: BlockGenerator = - new BlockGenerator(storageLevel) - - protected lazy val supervisor = env.actorSystem.actorOf(Props(new Supervisor), - "Supervisor" + streamId) - - private class Supervisor extends Actor { - - override val supervisorStrategy = receiverSupervisorStrategy - val worker = context.actorOf(props, name) - logInfo("Started receiver worker at:" + worker.path) - - val n: AtomicInteger = new AtomicInteger(0) - val hiccups: AtomicInteger = new AtomicInteger(0) - - def receive = { - - case Data(iter: Iterator[_]) ⇒ pushBlock(iter.asInstanceOf[Iterator[T]]) - - case Data(msg) ⇒ - blocksGenerator += msg.asInstanceOf[T] - n.incrementAndGet - - case props: Props ⇒ - val worker = context.actorOf(props) - logInfo("Started receiver worker at:" + worker.path) - sender ! worker - - case (props: Props, name: String) ⇒ - val worker = context.actorOf(props, name) - logInfo("Started receiver worker at:" + worker.path) - sender ! worker - - case _: PossiblyHarmful => hiccups.incrementAndGet() - - case _: Statistics ⇒ - val workers = context.children - sender ! 
Statistics(n.get, workers.size, hiccups.get, workers.mkString("\n")) - - } - } - - protected def pushBlock(iter: Iterator[T]) { - val buffer = new ArrayBuffer[T] - buffer ++= iter - pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel) - } - - protected def onStart() = { - blocksGenerator.start() - supervisor - logInfo("Supervision tree for receivers initialized at:" + supervisor.path) - } - - protected def onStop() = { - supervisor ! PoisonPill - } - -} diff --git a/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala deleted file mode 100644 index 22d554e7e4..0000000000 --- a/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.receivers - -import akka.actor.Actor -import akka.zeromq._ - -import spark.Logging - -/** - * A receiver to subscribe to ZeroMQ stream. - */ -private[streaming] class ZeroMQReceiver[T: ClassManifest](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T]) - extends Actor with Receiver with Logging { - - override def preStart() = context.system.newSocket(SocketType.Sub, Listener(self), - Connect(publisherUrl), subscribe) - - def receive: Receive = { - - case Connecting ⇒ logInfo("connecting ...") - - case m: ZMQMessage ⇒ - logDebug("Received message for:" + m.firstFrameAsString) - - //We ignore first frame for processing as it is the topic - val bytes = m.frames.tail.map(_.payload) - pushBlock(bytesToObjects(bytes)) - - case Closed ⇒ logInfo("received closed ") - - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala deleted file mode 100644 index d9ac722df5..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/Clock.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
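The Receiver trait and ActorReceiver removed above are what StreamingContext.actorStream wires a user-supplied actor into, with ZeroMQReceiver below as one concrete mix-in. A sketch along the lines of the scaladoc @example, against the pre-rename spark.streaming API; the actor class and receiver name are made up:

    import akka.actor.{Actor, Props}
    import spark.streaming.{Seconds, StreamingContext}
    import spark.streaming.receivers.Receiver

    // An actor that forwards every String it receives to Spark via pushBlock
    class EchoActor extends Actor with Receiver {
      def receive = {
        case s: String => pushBlock(s)
      }
    }

    object ActorStreamSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[2]", "ActorStreamSketch", Seconds(1))
        // The element type must match what the actor pushes, as the scaladoc above warns
        val stream = ssc.actorStream[String](Props(new EchoActor), "EchoReceiver")
        stream.count().print()
        ssc.start()
      }
    }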
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.util - -private[streaming] -trait Clock { - def currentTime(): Long - def waitTillTime(targetTime: Long): Long -} - -private[streaming] -class SystemClock() extends Clock { - - val minPollTime = 25L - - def currentTime(): Long = { - System.currentTimeMillis() - } - - def waitTillTime(targetTime: Long): Long = { - var currentTime = 0L - currentTime = System.currentTimeMillis() - - var waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - - val pollTime = { - if (waitTime / 10.0 > minPollTime) { - (waitTime / 10.0).toLong - } else { - minPollTime - } - } - - - while (true) { - currentTime = System.currentTimeMillis() - waitTime = targetTime - currentTime - - if (waitTime <= 0) { - - return currentTime - } - val sleepTime = - if (waitTime < pollTime) { - waitTime - } else { - pollTime - } - Thread.sleep(sleepTime) - } - return -1 - } -} - -private[streaming] -class ManualClock() extends Clock { - - var time = 0L - - def currentTime() = time - - def setTime(timeToSet: Long) = { - this.synchronized { - time = timeToSet - this.notifyAll() - } - } - - def addToTime(timeToAdd: Long) = { - this.synchronized { - time += timeToAdd - this.notifyAll() - } - } - def waitTillTime(targetTime: Long): Long = { - this.synchronized { - while (time < targetTime) { - this.wait(100) - } - } - return currentTime() - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala deleted file mode 100644 index 8ce5d8daf5..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ /dev/null @@ -1,414 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
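The Clock trait deleted above exists so tests can swap the wall-clock SystemClock for a ManualClock; the JavaAPISuite added later in this patch selects it through the "spark.streaming.clock" system property. A small sketch of ManualClock's semantics, declared in the same package because the classes are private[streaming]; the object name and times are illustrative:

    package spark.streaming.util

    // Advancing a ManualClock by hand releases waiters deterministically
    object ManualClockSketch {
      def main(args: Array[String]) {
        val clock = new ManualClock()
        clock.setTime(0L)
        clock.addToTime(2000L)             // pretend two 1-second batches have elapsed
        println(clock.currentTime())       // 2000
        println(clock.waitTillTime(1500L)) // returns immediately: target already reached
      }
    }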
- */ - -package spark.streaming.util - -import spark.{Logging, RDD} -import spark.streaming._ -import spark.streaming.dstream.ForEachDStream -import StreamingContext._ - -import scala.util.Random -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} - -import java.io.{File, ObjectInputStream, IOException} -import java.util.UUID - -import com.google.common.io.Files - -import org.apache.commons.io.FileUtils -import org.apache.hadoop.fs.{FileUtil, FileSystem, Path} -import org.apache.hadoop.conf.Configuration - - -private[streaming] -object MasterFailureTest extends Logging { - initLogging() - - @volatile var killed = false - @volatile var killCount = 0 - - def main(args: Array[String]) { - if (args.size < 2) { - println( - "Usage: MasterFailureTest <# batches> []") - System.exit(1) - } - val directory = args(0) - val numBatches = args(1).toInt - val batchDuration = if (args.size > 2) Milliseconds(args(2).toInt) else Seconds(1) - - println("\n\n========================= MAP TEST =========================\n\n") - testMap(directory, numBatches, batchDuration) - - println("\n\n================= UPDATE-STATE-BY-KEY TEST =================\n\n") - testUpdateStateByKey(directory, numBatches, batchDuration) - - println("\n\nSUCCESS\n\n") - } - - def testMap(directory: String, numBatches: Int, batchDuration: Duration) { - // Input: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... - val input = (1 to numBatches).map(_.toString).toSeq - // Expected output: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... - val expectedOutput = (1 to numBatches) - - val operation = (st: DStream[String]) => st.map(_.toInt) - - // Run streaming operation with multiple master failures - val output = testOperation(directory, batchDuration, input, operation, expectedOutput) - - logInfo("Expected output, size = " + expectedOutput.size) - logInfo(expectedOutput.mkString("[", ",", "]")) - logInfo("Output, size = " + output.size) - logInfo(output.mkString("[", ",", "]")) - - // Verify whether all the values of the expected output is present - // in the output - assert(output.distinct.toSet == expectedOutput.toSet) - } - - - def testUpdateStateByKey(directory: String, numBatches: Int, batchDuration: Duration) { - // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq - // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... 
- val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j)) - - val operation = (st: DStream[String]) => { - val updateFunc = (values: Seq[Long], state: Option[Long]) => { - Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) - } - st.flatMap(_.split(" ")) - .map(x => (x, 1L)) - .updateStateByKey[Long](updateFunc) - .checkpoint(batchDuration * 5) - } - - // Run streaming operation with multiple master failures - val output = testOperation(directory, batchDuration, input, operation, expectedOutput) - - logInfo("Expected output, size = " + expectedOutput.size + "\n" + expectedOutput) - logInfo("Output, size = " + output.size + "\n" + output) - - // Verify whether all the values in the output are among the expected output values - output.foreach(o => - assert(expectedOutput.contains(o), "Expected value " + o + " not found") - ) - - // Verify whether the last expected output value has been generated, there by - // confirming that none of the inputs have been missed - assert(output.last == expectedOutput.last) - } - - /** - * Tests stream operation with multiple master failures, and verifies whether the - * final set of output values is as expected or not. - */ - def testOperation[T: ClassManifest]( - directory: String, - batchDuration: Duration, - input: Seq[String], - operation: DStream[String] => DStream[T], - expectedOutput: Seq[T] - ): Seq[T] = { - - // Just making sure that the expected output does not have duplicates - assert(expectedOutput.distinct.toSet == expectedOutput.toSet) - - // Setup the stream computation with the given operation - val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) - - // Start generating files in the a different thread - val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) - fileGeneratingThread.start() - - // Run the streams and repeatedly kill it until the last expected output - // has been generated, or until it has run for twice the expected time - val lastExpectedOutput = expectedOutput.last - val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 - val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) - - // Delete directories - fileGeneratingThread.join() - val fs = checkpointDir.getFileSystem(new Configuration()) - fs.delete(checkpointDir, true) - fs.delete(testDir, true) - logInfo("Finished test after " + killCount + " failures") - mergedOutput - } - - /** - * Sets up the stream computation with the given operation, directory (local or HDFS), - * and batch duration. Returns the streaming context and the directory to which - * files should be written for testing. 
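The operation driven through repeated master failures by testUpdateStateByKey is an ordinary stateful word count. Outside the test harness it would look roughly like the sketch below, against the pre-rename API; the app name, input directory and checkpoint interval are illustrative:

    import spark.streaming.{Seconds, StreamingContext}
    import spark.streaming.StreamingContext._

    object StatefulWordCountSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[4]", "StatefulWordCountSketch", Seconds(1))
        ssc.checkpoint("checkpoint")   // updateStateByKey requires a checkpoint directory

        // Running total per word, carried across batches
        val updateFunc = (values: Seq[Long], state: Option[Long]) =>
          Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L))

        val counts = ssc.textFileStream("test-input")
          .flatMap(_.split(" "))
          .map(x => (x, 1L))
          .updateStateByKey[Long](updateFunc)
          .checkpoint(Seconds(5))      // the test uses batchDuration * 5

        counts.print()
        ssc.start()
      }
    }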
- */ - private def setupStreams[T: ClassManifest]( - directory: String, - batchDuration: Duration, - operation: DStream[String] => DStream[T] - ): (StreamingContext, Path, Path) = { - // Reset all state - reset() - - // Create the directories for this test - val uuid = UUID.randomUUID().toString - val rootDir = new Path(directory, uuid) - val fs = rootDir.getFileSystem(new Configuration()) - val checkpointDir = new Path(rootDir, "checkpoint") - val testDir = new Path(rootDir, "test") - fs.mkdirs(checkpointDir) - fs.mkdirs(testDir) - - // Setup the streaming computation with the given operation - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) - ssc.checkpoint(checkpointDir.toString) - val inputStream = ssc.textFileStream(testDir.toString) - val operatedStream = operation(inputStream) - val outputStream = new TestOutputStream(operatedStream) - ssc.registerOutputStream(outputStream) - (ssc, checkpointDir, testDir) - } - - - /** - * Repeatedly starts and kills the streaming context until timed out or - * the last expected output is generated. Finally, return - */ - private def runStreams[T: ClassManifest]( - ssc_ : StreamingContext, - lastExpectedOutput: T, - maxTimeToRun: Long - ): Seq[T] = { - - var ssc = ssc_ - var totalTimeRan = 0L - var isLastOutputGenerated = false - var isTimedOut = false - val mergedOutput = new ArrayBuffer[T]() - val checkpointDir = ssc.checkpointDir - var batchDuration = ssc.graph.batchDuration - - while(!isLastOutputGenerated && !isTimedOut) { - // Get the output buffer - val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output - def output = outputBuffer.flatMap(x => x) - - // Start the thread to kill the streaming after some time - killed = false - val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10) - killingThread.start() - - var timeRan = 0L - try { - // Start the streaming computation and let it run while ... 
- // (i) StreamingContext has not been shut down yet - // (ii) The last expected output has not been generated yet - // (iii) Its not timed out yet - System.clearProperty("spark.streaming.clock") - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - ssc.start() - val startTime = System.currentTimeMillis() - while (!killed && !isLastOutputGenerated && !isTimedOut) { - Thread.sleep(100) - timeRan = System.currentTimeMillis() - startTime - isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput) - isTimedOut = (timeRan + totalTimeRan > maxTimeToRun) - } - } catch { - case e: Exception => logError("Error running streaming context", e) - } - if (killingThread.isAlive) killingThread.interrupt() - ssc.stop() - - logInfo("Has been killed = " + killed) - logInfo("Is last output generated = " + isLastOutputGenerated) - logInfo("Is timed out = " + isTimedOut) - - // Verify whether the output of each batch has only one element or no element - // and then merge the new output with all the earlier output - mergedOutput ++= output - totalTimeRan += timeRan - logInfo("New output = " + output) - logInfo("Merged output = " + mergedOutput) - logInfo("Time ran = " + timeRan) - logInfo("Total time ran = " + totalTimeRan) - - if (!isLastOutputGenerated && !isTimedOut) { - val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 10) - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation in " + sleepTime + " ms " + - "\n-------------------------------------------\n" - ) - Thread.sleep(sleepTime) - // Recreate the streaming context from checkpoint - ssc = new StreamingContext(checkpointDir) - } - } - mergedOutput - } - - /** - * Verifies the output value are the same as expected. Since failures can lead to - * a batch being processed twice, a batches output may appear more than once - * consecutively. To avoid getting confused with those, we eliminate consecutive - * duplicate batch outputs of values from the `output`. As a result, the - * expected output should not have consecutive batches with the same values as output. - */ - private def verifyOutput[T: ClassManifest](output: Seq[T], expectedOutput: Seq[T]) { - // Verify whether expected outputs do not consecutive batches with same output - for (i <- 0 until expectedOutput.size - 1) { - assert(expectedOutput(i) != expectedOutput(i+1), - "Expected output has consecutive duplicate sequence of values") - } - - // Log the output - println("Expected output, size = " + expectedOutput.size) - println(expectedOutput.mkString("[", ",", "]")) - println("Output, size = " + output.size) - println(output.mkString("[", ",", "]")) - - // Match the output with the expected output - output.foreach(o => - assert(expectedOutput.contains(o), "Expected value " + o + " not found") - ) - } - - /** Resets counter to prepare for the test */ - private def reset() { - killed = false - killCount = 0 - } -} - -/** - * This is a output stream just for testing. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. 
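The recovery path that runStreams exercises after each simulated failure is the checkpoint constructor (ssc = new StreamingContext(checkpointDir) above): the driver is restarted from the checkpoint directory instead of rebuilding the graph by hand. A minimal sketch, assuming a previous run already checkpointed under "checkpoint":

    import spark.streaming.StreamingContext

    object RecoverFromCheckpointSketch {
      def main(args: Array[String]) {
        // Reconstructs the StreamingContext and its DStream graph from the
        // data written by ssc.checkpoint(...) in the earlier, failed run
        val ssc = new StreamingContext("checkpoint")
        ssc.start()
      }
    }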
- */ -private[streaming] -class TestOutputStream[T: ClassManifest]( - parent: DStream[T], - val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] - ) extends ForEachDStream[T]( - parent, - (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output += collected - } - ) { - - // This is to clear the output buffer every it is read from a checkpoint - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - output.clear() - } -} - - -/** - * Thread to kill streaming context after a random period of time. - */ -private[streaming] -class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { - initLogging() - - override def run() { - try { - // If it is the first killing, then allow the first checkpoint to be created - var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 2000 - val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) - logInfo("Kill wait time = " + killWaitTime) - Thread.sleep(killWaitTime) - logInfo( - "\n---------------------------------------\n" + - "Killing streaming context after " + killWaitTime + " ms" + - "\n---------------------------------------\n" - ) - if (ssc != null) { - ssc.stop() - MasterFailureTest.killed = true - MasterFailureTest.killCount += 1 - } - logInfo("Killing thread finished normally") - } catch { - case ie: InterruptedException => logInfo("Killing thread interrupted") - case e: Exception => logWarning("Exception in killing thread", e) - } - - } -} - - -/** - * Thread to generate input files periodically with the desired text. - */ -private[streaming] -class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) - extends Thread with Logging { - initLogging() - - override def run() { - val localTestDir = Files.createTempDir() - var fs = testDir.getFileSystem(new Configuration()) - val maxTries = 3 - try { - Thread.sleep(5000) // To make sure that all the streaming context has been set up - for (i <- 0 until input.size) { - // Write the data to a local file and then move it to the target test directory - val localFile = new File(localTestDir, (i+1).toString) - val hadoopFile = new Path(testDir, (i+1).toString) - val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString) - FileUtils.writeStringToFile(localFile, input(i).toString + "\n") - var tries = 0 - var done = false - while (!done && tries < maxTries) { - tries += 1 - try { - // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) - fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile) - fs.rename(tempHadoopFile, hadoopFile) - done = true - } catch { - case ioe: IOException => { - fs = testDir.getFileSystem(new Configuration()) - logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) - } - } - } - if (!done) - logError("Could not generate file " + hadoopFile) - else - logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis) - Thread.sleep(interval) - localFile.delete() - } - logInfo("File generating thread finished normally") - } catch { - case ie: InterruptedException => logInfo("File generating thread interrupted") - case e: Exception => logWarning("File generating in killing thread", e) - } finally { - fs.close() - } - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala deleted file mode 100644 index bf04120293..0000000000 --- 
a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.util - -import spark.SparkContext -import spark.SparkContext._ -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} -import scala.collection.JavaConversions.mapAsScalaMap - -object RawTextHelper { - - /** - * Splits lines and counts the words in them using specialized object-to-long hashmap - * (to avoid boxing-unboxing overhead of Long in java/scala HashMap) - */ - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OLMap[String] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.getLong(w) - map.put(w, c + 1) - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator.map{case (k, v) => (k, v)} - } - - /** - * Gets the top k words in terms of word counts. Assumes that each word exists only once - * in the `data` iterator (that is, the counts have been reduced). - */ - def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { - val taken = new Array[(String, Long)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Long) = null - var swap: (String, Long) = null - var count = 0 - - while(data.hasNext) { - value = data.next - if (value != null) { - count += 1 - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - } - return taken.toIterator - } - - /** - * Warms up the SparkContext in master and slave by running tasks to force JIT kick in - * before real workload starts. 
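The two helpers above are meant to be chained: splitAndCountPartitions emits per-partition (word, count) pairs without boxing, and topK keeps the k largest entries of an already-reduced iterator. A rough sketch of that chaining on a plain RDD; the local master, sample data and k are made up:

    import spark.SparkContext
    import spark.SparkContext._
    import spark.streaming.util.RawTextHelper._

    object TopKWordsSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "TopKWordsSketch")
        val lines = sc.parallelize(Seq("a a b", "b c a"), 2)
        val counts = lines.mapPartitions(splitAndCountPartitions).reduceByKey(add, 2)
        // topK expects counts that have already been reduced, per the comment above
        topK(counts.collect().iterator, 2).foreach(println)
      }
    }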
- */ - def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { - sc.parallelize(1 to 200000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - .count() - } - } - - def add(v1: Long, v2: Long) = (v1 + v2) - - def subtract(v1: Long, v2: Long) = (v1 - v2) - - def max(v1: Long, v2: Long) = math.max(v1, v2) -} - diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala deleted file mode 100644 index 5cc6ad9dee..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.util - -import java.nio.ByteBuffer -import spark.util.{RateLimitedOutputStream, IntParam} -import java.net.ServerSocket -import spark.{Logging, KryoSerializer} -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import scala.io.Source -import java.io.IOException - -/** - * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a - * specified rate. Used to feed data into RawInputDStream. 
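RawTextSender below is normally launched standalone; its four positional arguments (port, input file, block size, bytes per second) can be read off the pattern match on args. A sketch of driving it from Scala with made-up values; in practice it runs in its own JVM and never returns, since it accepts connections in an infinite loop:

    import spark.streaming.util.RawTextSender

    object SendRawTextSketch {
      def main(args: Array[String]) {
        // port, input file, block size in bytes, target rate in bytes/sec (all illustrative)
        RawTextSender.main(Array("9999", "lines.txt", "65536", "100000"))
      }
    }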
- */ -object RawTextSender extends Logging { - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: RawTextSender ") - System.exit(1) - } - // Parse the arguments using a pattern match - val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args - - // Repeat the input data multiple times to fill in a buffer - val lines = Source.fromFile(file).getLines().toArray - val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) - val ser = new KryoSerializer().newInstance() - val serStream = ser.serializeStream(bufferStream) - var i = 0 - while (bufferStream.position < blockSize) { - serStream.writeObject(lines(i)) - i = (i + 1) % lines.length - } - bufferStream.trim() - val array = bufferStream.array - - val countBuf = ByteBuffer.wrap(new Array[Byte](4)) - countBuf.putInt(array.length) - countBuf.flip() - - val serverSocket = new ServerSocket(port) - logInfo("Listening on port " + port) - - while (true) { - val socket = serverSocket.accept() - logInfo("Got a new connection") - val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) - try { - while (true) { - out.write(countBuf.array) - out.write(array) - } - } catch { - case e: IOException => - logError("Client disconnected") - socket.close() - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala deleted file mode 100644 index 7ecc44236d..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.util - -private[streaming] -class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - - private val minPollTime = 25L - - private val pollTime = { - if (period / 10.0 > minPollTime) { - (period / 10.0).toLong - } else { - minPollTime - } - } - - private val thread = new Thread() { - override def run() { loop } - } - - private var nextTime = 0L - - def getStartTime(): Long = { - (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period - } - - def getRestartTime(originalStartTime: Long): Long = { - val gap = clock.currentTime - originalStartTime - (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime - } - - def start(startTime: Long): Long = { - nextTime = startTime - thread.start() - nextTime - } - - def start(): Long = { - start(getStartTime()) - } - - def stop() { - thread.interrupt() - } - - private def loop() { - try { - while (true) { - clock.waitTillTime(nextTime) - callback(nextTime) - nextTime += period - } - - } catch { - case e: InterruptedException => - } - } -} - -private[streaming] -object RecurringTimer { - - def main(args: Array[String]) { - var lastRecurTime = 0L - val period = 1000 - - def onRecur(time: Long) { - val currentTime = System.currentTimeMillis() - println("" + currentTime + ": " + (currentTime - lastRecurTime)) - lastRecurTime = currentTime - } - val timer = new RecurringTimer(new SystemClock(), period, onRecur) - timer.start() - Thread.sleep(30 * 1000) - timer.stop() - } -} - diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java new file mode 100644 index 0000000000..c0d729ff87 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -0,0 +1,1304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
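Tying the two private[streaming] utilities above together: a RecurringTimer driven by a ManualClock fires its callback once per period as the clock is advanced by hand, which is what lets the streaming tests generate batches deterministically. A sketch, declared in the same package for access; the period and times are illustrative:

    package spark.streaming.util

    // Callbacks fire for t = 1000, 2000, 3000 once the manual clock reaches 3000
    object RecurringTimerSketch {
      def main(args: Array[String]) {
        val clock = new ManualClock()
        val timer = new RecurringTimer(clock, 1000, (t: Long) => println("tick at " + t))
        timer.start(1000)       // first callback scheduled for t = 1000
        clock.addToTime(3000)   // wakes waitTillTime for each of the three periods
        Thread.sleep(100)       // let the timer thread drain before stopping it
        timer.stop()
      }
    }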
+ */ + +package org.apache.spark.streaming; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.Files; +import kafka.serializer.StringDecoder; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaRDDLike; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.JavaTestUtils; +import org.apache.spark.streaming.JavaCheckpointTestUtils; +import org.apache.spark.streaming.InputStreamsSuite; + +import java.io.*; +import java.util.*; + +import akka.actor.Props; +import akka.zeromq.Subscribe; + + + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port"); + } + + @Test + public void testCount() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3,4), + Arrays.asList(3,4,5), + Arrays.asList(3)); + + List> expected = Arrays.asList( + Arrays.asList(4L), + Arrays.asList(3L), + Arrays.asList(1L)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream count = stream.count(); + JavaTestUtils.attachTestOutputStream(count); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testMap() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("goodnight", "moon")); + + List> expected = Arrays.asList( + Arrays.asList(5,5), + Arrays.asList(9,4)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaTestUtils.attachTestOutputStream(letterCount); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6,1,2,3), + Arrays.asList(7,8,9,4,5,6), + Arrays.asList(7,8,9)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, 
inputData, 1); + JavaDStream windowed = stream.window(new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindowWithSlideDuration() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), + Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 8, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("yankees")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream filtered = stream.filter(new Function() { + @Override + public Boolean call(String s) throws Exception { + return s.contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testGlom() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(Arrays.asList("giants", "dodgers")), + Arrays.asList(Arrays.asList("yankees", "red socks"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream glommed = stream.glom(); + JavaTestUtils.attachTestOutputStream(glommed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapPartitions() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("GIANTSDODGERS"), + Arrays.asList("YANKEESRED SOCKS")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator in) { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out); + } + }); + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + private class IntegerSum extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + } + + private class IntegerDifference extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 - i2; + } + } + + @Test + public void testReduce() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(15), + Arrays.asList(24)); + + JavaDStream stream = 
JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reduced = stream.reduce(new IntegerSum()); + JavaTestUtils.attachTestOutputStream(reduced); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(21), + Arrays.asList(39), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new IntegerDifference(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reducedWindowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + Assert.assertEquals(expected, result); + } + + @Test + public void testQueueStream() { + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); + JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); + JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); + JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); + + LinkedList> rdds = Lists.newLinkedList(); + rdds.add(rdd1); + rdds.add(rdd2); + rdds.add(rdd3); + + JavaDStream stream = ssc.queueStream(rdds); + JavaTestUtils.attachTestOutputStream(stream); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + Assert.assertEquals(expected, result); + } + + @Test + public void testTransform() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(3,4,5), + Arrays.asList(6,7,8), + Arrays.asList(9,10,11)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream transformed = + stream.transform(new Function, JavaRDD>() { + @Override + public JavaRDD call(JavaRDD in) throws Exception { + return in.map(new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i + 2; + } + }); + }}); + JavaTestUtils.attachTestOutputStream(transformed); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("go", "giants"), + Arrays.asList("boo", "dodgers"), + Arrays.asList("athletics")); + + List> expected = Arrays.asList( + Arrays.asList("g","o","g","i","a","n","t","s"), + Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), + Arrays.asList("a","t","h","l","e","t","i","c","s")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split("(?!^)")); + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testPairFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("dodgers"), + Arrays.asList("athletics")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(6, "g"), + new Tuple2(6, "i"), + new Tuple2(6, "a"), + new Tuple2(6, "n"), + new Tuple2(6, "t"), + new Tuple2(6, "s")), + Arrays.asList( + new Tuple2(7, 
"d"), + new Tuple2(7, "o"), + new Tuple2(7, "d"), + new Tuple2(7, "g"), + new Tuple2(7, "e"), + new Tuple2(7, "r"), + new Tuple2(7, "s")), + Arrays.asList( + new Tuple2(9, "a"), + new Tuple2(9, "t"), + new Tuple2(9, "h"), + new Tuple2(9, "l"), + new Tuple2(9, "e"), + new Tuple2(9, "t"), + new Tuple2(9, "i"), + new Tuple2(9, "c"), + new Tuple2(9, "s"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { + @Override + public Iterable> call(String in) throws Exception { + List> out = Lists.newArrayList(); + for (String letter: in.split("(?!^)")) { + out.add(new Tuple2(in.length(), letter)); + } + return out; + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUnion() { + List> inputData1 = Arrays.asList( + Arrays.asList(1,1), + Arrays.asList(2,2), + Arrays.asList(3,3)); + + List> inputData2 = Arrays.asList( + Arrays.asList(4,4), + Arrays.asList(5,5), + Arrays.asList(6,6)); + + List> expected = Arrays.asList( + Arrays.asList(1,1,4,4), + Arrays.asList(2,2,5,5), + Arrays.asList(3,3,6,6)); + + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); + + JavaDStream unioned = stream1.union(stream2); + JavaTestUtils.attachTestOutputStream(unioned); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + /* + * Performs an order-invariant comparison of lists representing two RDD streams. This allows + * us to account for ordering variation within individual RDD's which occurs during windowing. 
+ */ + public static void assertOrderInvariantEquals( + List> expected, List> actual) { + for (List list: expected) { + Collections.sort(list); + } + for (List list: actual) { + Collections.sort(list); + } + Assert.assertEquals(expected, actual); + } + + + // PairDStream Functions + @Test + public void testPairFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("giants", 6)), + Arrays.asList(new Tuple2("yankees", 7))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = stream.map( + new PairFunction() { + @Override + public Tuple2 call(String in) throws Exception { + return new Tuple2(in, in.length()); + } + }); + + JavaPairDStream filtered = pairStream.filter( + new Function, Boolean>() { + @Override + public Boolean call(Tuple2 in) throws Exception { + return in._1().contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("california", "giants"), + new Tuple2("new york", "yankees"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("california", "ducks"), + new Tuple2("new york", "rangers"), + new Tuple2("new york", "islanders"))); + + List>> stringIntKVStream = Arrays.asList( + Arrays.asList( + new Tuple2("california", 1), + new Tuple2("california", 3), + new Tuple2("new york", 4), + new Tuple2("new york", 1)), + Arrays.asList( + new Tuple2("california", 5), + new Tuple2("california", 5), + new Tuple2("new york", 3), + new Tuple2("new york", 1))); + + @Test + public void testPairMap() { // Maps pair -> pair of different type + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, "california"), + new Tuple2(3, "california"), + new Tuple2(4, "new york"), + new Tuple2(1, "new york")), + Arrays.asList( + new Tuple2(5, "california"), + new Tuple2(5, "california"), + new Tuple2(3, "new york"), + new Tuple2(1, "new york"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream reversed = pairStream.map( + new PairFunction, Integer, String>() { + @Override + public Tuple2 call(Tuple2 in) throws Exception { + return in.swap(); + } + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMapPartitions() { // Maps pair -> pair of different type + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, "california"), + new Tuple2(3, "california"), + new Tuple2(4, "new york"), + new Tuple2(1, "new york")), + Arrays.asList( + new Tuple2(5, "california"), + new Tuple2(5, "california"), + new Tuple2(3, "new york"), + new Tuple2(1, "new york"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream reversed = pairStream.mapPartitions( + new PairFlatMapFunction>, Integer, String>() { + @Override + public Iterable> call(Iterator> in) throws Exception { + 
LinkedList> out = new LinkedList>(); + while (in.hasNext()) { + Tuple2 next = in.next(); + out.add(next.swap()); + } + return out; + } + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMap2() { // Maps pair -> single + List>> inputData = stringIntKVStream; + + List> expected = Arrays.asList( + Arrays.asList(1, 3, 4, 1), + Arrays.asList(5, 5, 3, 1)); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream reversed = pairStream.map( + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2("hi", 1), + new Tuple2("ho", 2)), + Arrays.asList( + new Tuple2("hi", 1), + new Tuple2("ho", 2))); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, "h"), + new Tuple2(1, "i"), + new Tuple2(2, "h"), + new Tuple2(2, "o")), + Arrays.asList( + new Tuple2(1, "h"), + new Tuple2(1, "i"), + new Tuple2(2, "h"), + new Tuple2(2, "o"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream flatMapped = pairStream.flatMap( + new PairFlatMapFunction, Integer, String>() { + @Override + public Iterable> call(Tuple2 in) throws Exception { + List> out = new LinkedList>(); + for (Character s : in._1().toCharArray()) { + out.add(new Tuple2(in._2(), s.toString())); + } + return out; + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairGroupByKey() { + List>> inputData = stringStringKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> grouped = pairStream.groupByKey(); + JavaTestUtils.attachTestOutputStream(grouped); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairReduceByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + 
public void testCombineByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream combined = pairStream.combineByKey( + new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i; + } + }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); + + JavaTestUtils.attachTestOutputStream(combined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByValue() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("hello", "moon"), + Arrays.asList("hello")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream counted = stream.countByValue(); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testGroupByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", Arrays.asList(1, 3)), + new Tuple2>("new york", Arrays.asList(1, 4)) + ), + Arrays.asList( + new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + ), + Arrays.asList( + new Tuple2>("california", Arrays.asList(5, 5)), + new Tuple2>("new york", Arrays.asList(1, 3)) + ) + ); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> groupWindowed = + pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(groupWindowed); + List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assert(result.size() == expected.size()); + for (int i = 0; i < result.size(); i++) { + assert(convert(result.get(i)).equals(convert(expected.get(i)))); + } + } + + private HashSet>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList>>(); + for (Tuple2> tuple: listOfTuples) { + newListOfTuples.add(convert(tuple)); + } + return new HashSet>>(newListOfTuples); + } + + private Tuple2> convert(Tuple2> tuple) { + return new Tuple2>(tuple._1(), new HashSet(tuple._2())); + } + + @Test + public void testReduceByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); + 
JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUpdateStateByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream updated = pairStream.updateStateByKey( + new Function2, Optional, Optional>() { + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; + } + return Optional.of(out); + } + }); + JavaTestUtils.attachTestOutputStream(updated); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindowWithInverse() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByValueAndWindow() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("hello", "moon"), + Arrays.asList("hello")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 2L), + new Tuple2("world", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 2L), + new Tuple2("moon", 1L))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream counted = + stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairTransform() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2(3, 5), + new Tuple2(1, 5), + new Tuple2(4, 5), + new Tuple2(2, 5)), + Arrays.asList( + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5), + new Tuple2(1, 5))); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, 5), + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5)), + Arrays.asList( + new Tuple2(1, 5), + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream sorted = pairStream.transform( + new Function, JavaPairRDD>() { + @Override + public 
JavaPairRDD call(JavaPairRDD in) throws Exception { + return in.sortByKey(); + } + }); + + JavaTestUtils.attachTestOutputStream(sorted); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToNormalRDDTransform() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2(3, 5), + new Tuple2(1, 5), + new Tuple2(4, 5), + new Tuple2(2, 5)), + Arrays.asList( + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5), + new Tuple2(1, 5))); + + List> expected = Arrays.asList( + Arrays.asList(3,1,4,2), + Arrays.asList(2,3,4,1)); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaDStream firstParts = pairStream.transform( + new Function, JavaRDD>() { + @Override + public JavaRDD call(JavaPairRDD in) throws Exception { + return in.map(new Function, Integer>() { + @Override + public Integer call(Tuple2 in) { + return in._1(); + } + }); + } + }); + + JavaTestUtils.attachTestOutputStream(firstParts); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + public void testMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "DODGERS"), + new Tuple2("california", "GIANTS"), + new Tuple2("new york", "YANKEES"), + new Tuple2("new york", "METS")), + Arrays.asList(new Tuple2("california", "SHARKS"), + new Tuple2("california", "DUCKS"), + new Tuple2("new york", "RANGERS"), + new Tuple2("new york", "ISLANDERS"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream mapped = pairStream.mapValues(new Function() { + @Override + public String call(String s) throws Exception { + return s.toUpperCase(); + } + }); + + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers1"), + new Tuple2("california", "dodgers2"), + new Tuple2("california", "giants1"), + new Tuple2("california", "giants2"), + new Tuple2("new york", "yankees1"), + new Tuple2("new york", "yankees2"), + new Tuple2("new york", "mets1"), + new Tuple2("new york", "mets2")), + Arrays.asList(new Tuple2("california", "sharks1"), + new Tuple2("california", "sharks2"), + new Tuple2("california", "ducks1"), + new Tuple2("california", "ducks2"), + new Tuple2("new york", "rangers1"), + new Tuple2("new york", "rangers2"), + new Tuple2("new york", "islanders1"), + new Tuple2("new york", "islanders2"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + + JavaPairDStream flatMapped = pairStream.flatMapValues( + new Function>() { + @Override + public Iterable call(String in) { + List out = new ArrayList(); + out.add(in + "1"); + out.add(in + "2"); + return out; + } + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCoGroup() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new 
Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List, List>>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); + JavaTestUtils.attachTestOutputStream(grouped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testJoin() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", + new Tuple2("dodgers", "giants")), + new Tuple2>("new york", + new Tuple2("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", + new Tuple2("sharks", "ducks")), + new Tuple2>("new york", + new Tuple2("rangers", "islanders")))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream> joined = pairStream1.join(pairStream2); + JavaTestUtils.attachTestOutputStream(joined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCheckpointMasterRecovery() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expectedInitial = Arrays.asList( + Arrays.asList(4,2)); + List> expectedFinal = Arrays.asList( + Arrays.asList(1,4), + Arrays.asList(8,7)); + + File tempDir = Files.createTempDir(); + ssc.checkpoint(tempDir.getAbsolutePath()); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + 
JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); + + assertOrderInvariantEquals(expectedInitial, initialResult); + Thread.sleep(1000); + ssc.stop(); + + ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); + // Tweak to take into consideration that the last batch before failure + // will be re-processed after recovery + List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 3); + assertOrderInvariantEquals(expectedFinal, finalResult.subList(1, 3)); + } + + + /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD + @Test + public void testCheckpointofIndividualStream() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expected = Arrays.asList( + Arrays.asList(4,2), + Arrays.asList(1,4), + Arrays.asList(8,7)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + + letterCount.checkpoint(new Duration(1000)); + + List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); + assertOrderInvariantEquals(expected, result1); + } + */ + + // Input stream tests. These mostly just test that we can instantiate a given InputStream with + // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the + // InputStream functionality is deferred to the existing Scala tests. + @Test + public void testKafkaStream() { + HashMap topics = Maps.newHashMap(); + JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, + StorageLevel.MEMORY_AND_DISK()); + + HashMap kafkaParams = Maps.newHashMap(); + kafkaParams.put("zk.connect","localhost:12345"); + kafkaParams.put("groupid","consumer-group"); + JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics, + StorageLevel.MEMORY_AND_DISK()); + } + + @Test + public void testSocketTextStream() { + JavaDStream test = ssc.socketTextStream("localhost", 12345); + } + + @Test + public void testSocketString() { + class Converter extends Function> { + public Iterable call(InputStream in) { + BufferedReader reader = new BufferedReader(new InputStreamReader(in)); + List out = new ArrayList(); + try { + while (true) { + String line = reader.readLine(); + if (line == null) { break; } + out.add(line); + } + } catch (IOException e) { } + return out; + } + } + + JavaDStream test = ssc.socketStream( + "localhost", + 12345, + new Converter(), + StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testTextFileStream() { + JavaDStream test = ssc.textFileStream("/tmp/foo"); + } + + @Test + public void testRawSocketStream() { + JavaDStream test = ssc.rawSocketStream("localhost", 12345); + } + + @Test + public void testFlumeStream() { + JavaDStream test = ssc.flumeStream("localhost", 12345, StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testFileStream() { + JavaPairDStream foo = + ssc.fileStream("/tmp/foo"); + } + + @Test + public void testTwitterStream() { + String[] filters = new String[] { "good", "bad", "ugly" }; + JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testActorStream() { + 
JavaDStream test = ssc.actorStream((Props)null, "TestActor", StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testZeroMQStream() { + JavaDStream test = ssc.zeroMQStream("url", (Subscribe) null, new Function>() { + @Override + public Iterable call(byte[][] b) throws Exception { + return null; + } + }); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala new file mode 100644 index 0000000000..8a6604904d --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} +import org.apache.spark.streaming._ +import java.util.ArrayList +import collection.JavaConversions._ +import org.apache.spark.api.java.JavaRDDLike + +/** Exposes streaming test functionality in a Java-friendly way. */ +trait JavaTestBase extends TestSuiteBase { + + /** + * Create a [[org.apache.spark.streaming.TestInputStream]] and attach it to the supplied context. + * The stream will be derived from the supplied lists of Java objects. + **/ + def attachTestInputStream[T]( + ssc: JavaStreamingContext, + data: JList[JList[T]], + numPartitions: Int) = { + val seqData = data.map(Seq(_:_*)) + + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) + ssc.ssc.registerInputStream(dstream) + new JavaDStream[T](dstream) + } + + /** + * Attach a provided stream to it's associated StreamingContext as a + * [[org.apache.spark.streaming.TestOutputStream]]. + **/ + def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]]( + dstream: JavaDStreamLike[T, This, R]) = + { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val ostream = new TestOutputStream(dstream.dstream, + new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) + dstream.dstream.ssc.registerOutputStream(ostream) + } + + /** + * Process all registered streams for a numBatches batches, failing if + * numExpectedOutput RDD's are not generated. Generated RDD's are collected + * and returned, represented as a list for each batch interval. 
+   */
+  def runStreams[V](
+    ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+    implicit val cm: ClassManifest[V] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+    val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
+    val out = new ArrayList[JList[V]]()
+    res.map(entry => out.append(new ArrayList[V](entry)))
+    out
+  }
+}
+
+object JavaTestUtils extends JavaTestBase {
+  override def maxWaitTimeMillis = 20000
+
+}
+
+object JavaCheckpointTestUtils extends JavaTestBase {
+  override def actuallyWait = true
+}
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
deleted file mode 100644
index 3b93790baa..0000000000
--- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java
+++ /dev/null
@@ -1,1304 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.streaming;
-
-import com.google.common.base.Optional;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.io.Files;
-import kafka.serializer.StringDecoder;
-import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-import scala.Tuple2;
-import spark.HashPartitioner;
-import spark.api.java.JavaPairRDD;
-import spark.api.java.JavaRDD;
-import spark.api.java.JavaRDDLike;
-import spark.api.java.JavaPairRDD;
-import spark.api.java.JavaSparkContext;
-import spark.api.java.function.*;
-import spark.storage.StorageLevel;
-import spark.streaming.api.java.JavaDStream;
-import spark.streaming.api.java.JavaPairDStream;
-import spark.streaming.api.java.JavaStreamingContext;
-import spark.streaming.JavaTestUtils;
-import spark.streaming.JavaCheckpointTestUtils;
-import spark.streaming.InputStreamsSuite;
-
-import java.io.*;
-import java.util.*;
-
-import akka.actor.Props;
-import akka.zeromq.Subscribe;
-
-
-
-// The test suite itself is Serializable so that anonymous Function implementations can be
-// serialized, as an alternative to converting these anonymous classes to static inner classes;
-// see http://stackoverflow.com/questions/758570/.
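For illustration, a minimal self-contained sketch of the closure-serialization point made in the comment above; the class, field, and method names are hypothetical, and the snippet simply reuses the Function and JavaDStream APIs already exercised by this suite (shown with the renamed org.apache.spark packages introduced by this patch).

import java.io.Serializable;

import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.api.java.JavaDStream;

// Hypothetical example: the anonymous Function captures the enclosing instance
// (OuterSuite.this) because it reads the "offset" field, so Spark can serialize the
// closure only if OuterSuite itself is Serializable -- the same reason the test suite
// implements Serializable instead of converting its anonymous classes to static ones.
class OuterSuite implements Serializable {
  private final int offset = 2;  // field captured by the anonymous class below

  JavaDStream<Integer> shiftBy(JavaDStream<Integer> stream) {
    return stream.map(new Function<Integer, Integer>() {
      @Override
      public Integer call(Integer i) throws Exception {
        return i + offset;  // implicit reference to OuterSuite.this
      }
    });
  }
}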
-public class JavaAPISuite implements Serializable { - private transient JavaStreamingContext ssc; - - @Before - public void setUp() { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port"); - } - - @Test - public void testCount() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3,4), - Arrays.asList(3,4,5), - Arrays.asList(3)); - - List> expected = Arrays.asList( - Arrays.asList(4L), - Arrays.asList(3L), - Arrays.asList(1L)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream count = stream.count(); - JavaTestUtils.attachTestOutputStream(count); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testMap() { - List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), - Arrays.asList("goodnight", "moon")); - - List> expected = Arrays.asList( - Arrays.asList(5,5), - Arrays.asList(9,4)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaTestUtils.attachTestOutputStream(letterCount); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6,1,2,3), - Arrays.asList(7,8,9,4,5,6), - Arrays.asList(7,8,9)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.window(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testWindowWithSlideDuration() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), - Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 8, 4); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List> expected = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("yankees")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream filtered = stream.filter(new Function() { - @Override - public Boolean call(String s) throws Exception { - return s.contains("a"); - } - }); - 
JavaTestUtils.attachTestOutputStream(filtered); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testGlom() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List>> expected = Arrays.asList( - Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream glommed = stream.glom(); - JavaTestUtils.attachTestOutputStream(glommed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testMapPartitions() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List> expected = Arrays.asList( - Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { - @Override - public Iterable call(Iterator in) { - String out = ""; - while (in.hasNext()) { - out = out + in.next().toUpperCase(); - } - return Lists.newArrayList(out); - } - }); - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - private class IntegerSum extends Function2 { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - } - - private class IntegerDifference extends Function2 { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 - i2; - } - } - - @Test - public void testReduce() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(15), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reduced = stream.reduce(new IntegerSum()); - JavaTestUtils.attachTestOutputStream(reduced); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(21), - Arrays.asList(39), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reducedWindowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - Assert.assertEquals(expected, result); - } - - @Test - public void testQueueStream() { - List> expected = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); - JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); - JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); - - LinkedList> rdds = Lists.newLinkedList(); - rdds.add(rdd1); - rdds.add(rdd2); - rdds.add(rdd3); - - JavaDStream stream = ssc.queueStream(rdds); - 
JavaTestUtils.attachTestOutputStream(stream); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - Assert.assertEquals(expected, result); - } - - @Test - public void testTransform() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(3,4,5), - Arrays.asList(6,7,8), - Arrays.asList(9,10,11)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream transformed = - stream.transform(new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD in) throws Exception { - return in.map(new Function() { - @Override - public Integer call(Integer i) throws Exception { - return i + 2; - } - }); - }}); - JavaTestUtils.attachTestOutputStream(transformed); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("go", "giants"), - Arrays.asList("boo", "dodgers"), - Arrays.asList("athletics")); - - List> expected = Arrays.asList( - Arrays.asList("g","o","g","i","a","n","t","s"), - Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), - Arrays.asList("a","t","h","l","e","t","i","c","s")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { - @Override - public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testPairFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("dodgers"), - Arrays.asList("athletics")); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), - Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), - Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { - @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); - for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); - } - return out; - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUnion() { - List> inputData1 = Arrays.asList( - Arrays.asList(1,1), - Arrays.asList(2,2), - Arrays.asList(3,3)); - - List> inputData2 = Arrays.asList( - Arrays.asList(4,4), - Arrays.asList(5,5), - Arrays.asList(6,6)); - - List> expected = Arrays.asList( - Arrays.asList(1,1,4,4), - Arrays.asList(2,2,5,5), - Arrays.asList(3,3,6,6)); - - JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); - JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); - - JavaDStream unioned = stream1.union(stream2); - 
JavaTestUtils.attachTestOutputStream(unioned); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - /* - * Performs an order-invariant comparison of lists representing two RDD streams. This allows - * us to account for ordering variation within individual RDD's which occurs during windowing. - */ - public static void assertOrderInvariantEquals( - List> expected, List> actual) { - for (List list: expected) { - Collections.sort(list); - } - for (List list: actual) { - Collections.sort(list); - } - Assert.assertEquals(expected, actual); - } - - - // PairDStream Functions - @Test - public void testPairFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = stream.map( - new PairFunction() { - @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); - } - }); - - JavaPairDStream filtered = pairStream.filter( - new Function, Boolean>() { - @Override - public Boolean call(Tuple2 in) throws Exception { - return in._1().contains("a"); - } - }); - JavaTestUtils.attachTestOutputStream(filtered); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); - - List>> stringIntKVStream = Arrays.asList( - Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), - Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); - - @Test - public void testPairMap() { // Maps pair -> pair of different type - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), - Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.map( - new PairFunction, Integer, String>() { - @Override - public Tuple2 call(Tuple2 in) throws Exception { - return in.swap(); - } - }); - - JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairMapPartitions() { // Maps pair -> pair of different type - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), - Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new 
Tuple2(1, "new york"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapPartitions( - new PairFlatMapFunction>, Integer, String>() { - @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); - while (in.hasNext()) { - Tuple2 next = in.next(); - out.add(next.swap()); - } - return out; - } - }); - - JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairMap2() { // Maps pair -> single - List>> inputData = stringIntKVStream; - - List> expected = Arrays.asList( - Arrays.asList(1, 3, 4, 1), - Arrays.asList(5, 5, 3, 1)); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); - - JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair - List>> inputData = Arrays.asList( - Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), - Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), - Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream flatMapped = pairStream.flatMap( - new PairFlatMapFunction, Integer, String>() { - @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); - for (Character s : in._1().toCharArray()) { - out.add(new Tuple2(in._2(), s.toString())); - } - return out; - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairGroupByKey() { - List>> inputData = stringStringKVStream; - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream> grouped = pairStream.groupByKey(); - JavaTestUtils.attachTestOutputStream(grouped); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairReduceByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = 
JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); - - JavaTestUtils.attachTestOutputStream(reduced); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCombineByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream combined = pairStream.combineByKey( - new Function() { - @Override - public Integer call(Integer i) throws Exception { - return i; - } - }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); - - JavaTestUtils.attachTestOutputStream(combined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCountByValue() { - List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), - Arrays.asList("hello", "moon"), - Arrays.asList("hello")); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream counted = stream.countByValue(); - JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testGroupByKeyAndWindow() { - List>> inputData = stringIntKVStream; - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) - ), - Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) - ), - Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) - ) - ); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream> groupWindowed = - pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(groupWindowed); - List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assert(result.size() == expected.size()); - for (int i = 0; i < result.size(); i++) { - assert(convert(result.get(i)).equals(convert(expected.get(i)))); - } - } - - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); - for (Tuple2> tuple: listOfTuples) { - newListOfTuples.add(convert(tuple)); - } - return new HashSet>>(newListOfTuples); - } - - private Tuple2> convert(Tuple2> tuple) { - return new Tuple2>(tuple._1(), new HashSet(tuple._2())); - } - - @Test - public void testReduceByKeyAndWindow() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 
10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUpdateStateByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream updated = pairStream.updateStateByKey( - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; - } - return Optional.of(out); - } - }); - JavaTestUtils.attachTestOutputStream(updated); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindowWithInverse() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCountByValueAndWindow() { - List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), - Arrays.asList("hello", "moon"), - Arrays.asList("hello")); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream counted = - stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairTransform() { - List>> inputData = Arrays.asList( - Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), - Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), - Arrays.asList( - new Tuple2(1, 5), - new 
Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream sorted = pairStream.transform( - new Function, JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { - return in.sortByKey(); - } - }); - - JavaTestUtils.attachTestOutputStream(sorted); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairToNormalRDDTransform() { - List>> inputData = Arrays.asList( - Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), - Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); - - List> expected = Arrays.asList( - Arrays.asList(3,1,4,2), - Arrays.asList(2,3,4,1)); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaDStream firstParts = pairStream.transform( - new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaPairRDD in) throws Exception { - return in.map(new Function, Integer>() { - @Override - public Integer call(Tuple2 in) { - return in._1(); - } - }); - } - }); - - JavaTestUtils.attachTestOutputStream(firstParts); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - public void testMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream mapped = pairStream.mapValues(new Function() { - @Override - public String call(String s) throws Exception { - return s.toUpperCase(); - } - }); - - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testFlatMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - - JavaPairDStream flatMapped = pairStream.flatMapValues( - new Function>() { - @Override - public Iterable call(String in) { - List out = new ArrayList(); - out.add(in + "1"); - 
out.add(in + "2"); - return out; - } - }); - - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCoGroup() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); - - - List, List>>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), - Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); - - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); - JavaTestUtils.attachTestOutputStream(grouped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testJoin() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); - - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), - Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); - - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream> joined = pairStream1.join(pairStream2); - JavaTestUtils.attachTestOutputStream(joined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCheckpointMasterRecovery() throws InterruptedException { - List> inputData = Arrays.asList( - Arrays.asList("this", "is"), - Arrays.asList("a", "test"), - Arrays.asList("counting", "letters")); - - List> expectedInitial = Arrays.asList( - Arrays.asList(4,2)); - List> expectedFinal = Arrays.asList( - Arrays.asList(1,4), - Arrays.asList(8,7)); - - File tempDir = 
Files.createTempDir(); - ssc.checkpoint(tempDir.getAbsolutePath()); - - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); - - assertOrderInvariantEquals(expectedInitial, initialResult); - Thread.sleep(1000); - ssc.stop(); - - ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); - // Tweak to take into consideration that the last batch before failure - // will be re-processed after recovery - List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 3); - assertOrderInvariantEquals(expectedFinal, finalResult.subList(1, 3)); - } - - - /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD - @Test - public void testCheckpointofIndividualStream() throws InterruptedException { - List> inputData = Arrays.asList( - Arrays.asList("this", "is"), - Arrays.asList("a", "test"), - Arrays.asList("counting", "letters")); - - List> expected = Arrays.asList( - Arrays.asList(4,2), - Arrays.asList(1,4), - Arrays.asList(8,7)); - - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - - letterCount.checkpoint(new Duration(1000)); - - List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); - assertOrderInvariantEquals(expected, result1); - } - */ - - // Input stream tests. These mostly just test that we can instantiate a given InputStream with - // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the - // InputStream functionality is deferred to the existing Scala tests. 
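Each of the tests that follows is an instantiation-only check; a minimal sketch of that pattern is shown below (illustrative only: the master string, application name, host, and port are placeholders, and the sketch uses the renamed org.apache.spark packages). The stream is constructed and bound to a JavaDStream so the compiler verifies the Java-facing element type; no output stream is attached and no batches are run.

import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

// Sketch of the type-check-only pattern used by the input stream tests.
public class InputStreamTypeCheckSketch {
  public static void main(String[] args) {
    JavaStreamingContext ssc =
        new JavaStreamingContext("local[2]", "type-check", new Duration(1000));

    // Assigning to JavaDStream<String> is the whole point: this line fails to compile
    // if the Java API exposes the wrong element type. The context is never started,
    // so nothing is actually read from the socket.
    JavaDStream<String> lines = ssc.socketTextStream("localhost", 12345);

    ssc.stop();
  }
}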
- @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, - StorageLevel.MEMORY_AND_DISK()); - - HashMap kafkaParams = Maps.newHashMap(); - kafkaParams.put("zk.connect","localhost:12345"); - kafkaParams.put("groupid","consumer-group"); - JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics, - StorageLevel.MEMORY_AND_DISK()); - } - - @Test - public void testSocketTextStream() { - JavaDStream test = ssc.socketTextStream("localhost", 12345); - } - - @Test - public void testSocketString() { - class Converter extends Function> { - public Iterable call(InputStream in) { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - try { - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - } catch (IOException e) { } - return out; - } - } - - JavaDStream test = ssc.socketStream( - "localhost", - 12345, - new Converter(), - StorageLevel.MEMORY_ONLY()); - } - - @Test - public void testTextFileStream() { - JavaDStream test = ssc.textFileStream("/tmp/foo"); - } - - @Test - public void testRawSocketStream() { - JavaDStream test = ssc.rawSocketStream("localhost", 12345); - } - - @Test - public void testFlumeStream() { - JavaDStream test = ssc.flumeStream("localhost", 12345, StorageLevel.MEMORY_ONLY()); - } - - @Test - public void testFileStream() { - JavaPairDStream foo = - ssc.fileStream("/tmp/foo"); - } - - @Test - public void testTwitterStream() { - String[] filters = new String[] { "good", "bad", "ugly" }; - JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY()); - } - - @Test - public void testActorStream() { - JavaDStream test = ssc.actorStream((Props)null, "TestActor", StorageLevel.MEMORY_ONLY()); - } - - @Test - public void testZeroMQStream() { - JavaDStream test = ssc.zeroMQStream("url", (Subscribe) null, new Function>() { - @Override - public Iterable call(byte[][] b) throws Exception { - return null; - } - }); - } -} diff --git a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala deleted file mode 100644 index f9d25db8da..0000000000 --- a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming - -import collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import java.util.{List => JList} -import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ - -/** Exposes streaming test functionality in a Java-friendly way. */ -trait JavaTestBase extends TestSuiteBase { - - /** - * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. - * The stream will be derived from the supplied lists of Java objects. - **/ - def attachTestInputStream[T]( - ssc: JavaStreamingContext, - data: JList[JList[T]], - numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) - - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) - ssc.ssc.registerInputStream(dstream) - new JavaDStream[T](dstream) - } - - /** - * Attach a provided stream to it's associated StreamingContext as a - * [[spark.streaming.TestOutputStream]]. - **/ - def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T, This, R], - R <: spark.api.java.JavaRDDLike[T, R]]( - dstream: JavaDStreamLike[T, This, R]) = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val ostream = new TestOutputStream(dstream.dstream, - new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) - dstream.dstream.ssc.registerOutputStream(ostream) - } - - /** - * Process all registered streams for a numBatches batches, failing if - * numExpectedOutput RDD's are not generated. Generated RDD's are collected - * and returned, represented as a list for each batch interval. - */ - def runStreams[V]( - ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { - implicit val cm: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] - val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out - } -} - -object JavaTestUtils extends JavaTestBase { - override def maxWaitTimeMillis = 20000 - -} - -object JavaCheckpointTestUtils extends JavaTestBase { - override def actuallyWait = true -} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala new file mode 100644 index 0000000000..11586f72b6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.StreamingContext._ +import scala.runtime.RichInt +import util.ManualClock + +class BasicOperationsSuite extends TestSuiteBase { + + override def framework() = "BasicOperationsSuite" + + before { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + test("map") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.map(_.toString), + input.map(_.map(_.toString)) + ) + } + + test("flatMap") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), + input.map(_.flatMap(x => Array(x, x * 2))) + ) + } + + test("filter") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.filter(x => (x % 2 == 0)), + input.map(_.filter(x => (x % 2 == 0))) + ) + } + + test("glom") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq( + Seq( Seq(1, 2), Seq(3, 4) ), + Seq( Seq(5, 6), Seq(7, 8) ), + Seq( Seq(9, 10), Seq(11, 12) ) + ) + val operation = (r: DStream[Int]) => r.glom().map(_.toSeq) + testOperation(input, operation, output) + } + + test("mapPartitions") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) + val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) + testOperation(input, operation, output, true) + } + + test("groupByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), + Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), + true + ) + } + + test("reduceByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + true + ) + } + + test("reduce") { + testOperation( + Seq(1 to 4, 5 to 8, 9 to 12), + (s: DStream[Int]) => s.reduce(_ + _), + Seq(Seq(10), Seq(26), Seq(42)) + ) + } + + test("count") { + testOperation( + Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4), + (s: DStream[Int]) => s.count(), + Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L)) + ) + } + + test("countByValue") { + testOperation( + Seq(1 to 1, Seq(1, 1, 1), 1 to 2, Seq(1, 1, 2, 2)), + (s: DStream[Int]) => s.countByValue(), + Seq(Seq((1, 1L)), Seq((1, 3L)), Seq((1, 1L), (2, 1L)), Seq((2, 2L), (1, 2L))), + true + ) + } + + test("mapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10), + Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ), + true + ) + } + + test("flatMapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), + true + ) + } + + test("cogroup") { + val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( 
Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) + val outputData = Seq( + Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ), + Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ), + Seq( ("", (Seq(1), Seq())) ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("join") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) + val outputData = Seq( + Seq( ("a", (1, "x")), ("b", (1, "x")) ), + Seq( ("", (1, "x")) ), + Seq( ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).join(s2.map(x => (x,"x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("updateStateByKey") { + val inputData = + Seq( + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val updateStateOperation = (s: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) + } + s.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + } + + testOperation(inputData, updateStateOperation, outputData, true) + } + + test("updateStateByKey - object lifecycle") { + val inputData = + Seq( + Seq("a","b"), + null, + Seq("a","c","a"), + Seq("c"), + null, + null + ) + + val outputData = + Seq( + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 3), ("c", 1)), + Seq(("a", 3), ("c", 2)), + Seq(("c", 2)), + Seq() + ) + + val updateStateOperation = (s: DStream[String]) => { + class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable + + // updateFunc clears a state when a StateObject is seen without new values twice in a row + val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { + val stateObj = state.getOrElse(new StateObject) + values.foldLeft(0)(_ + _) match { + case 0 => stateObj.expireCounter += 1 // no new values + case n => { // has new values, increment and reset expireCounter + stateObj.counter += n + stateObj.expireCounter = 0 + } + } + stateObj.expireCounter match { + case 2 => None // seen twice with no new values, give it the boot + case _ => Option(stateObj) + } + } + s.map(x => (x, 1)).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) + } + + testOperation(inputData, updateStateOperation, outputData, true) + } + + test("slice") { + val ssc = new StreamingContext("local[2]", "BasicOperationSuite", Seconds(1)) + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + ssc.registerInputStream(stream) + stream.foreach(_ => {}) // Dummy output stream + ssc.start() + Thread.sleep(2000) + def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet + } + + assert(getInputFromSlice(0, 1000) == Set(1)) + assert(getInputFromSlice(0, 2000) == Set(1, 2)) + assert(getInputFromSlice(1000, 2000) == Set(1, 2)) + assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) + ssc.stop() + Thread.sleep(1000) + } + + test("forgetting 
of RDDs - map and window operations") { + assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") + + val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq + val rememberDuration = Seconds(3) + + assert(input.size === 10, "Number of inputs have changed") + + def operation(s: DStream[Int]): DStream[(Int, Int)] = { + s.map(x => (x % 10, 1)) + .window(Seconds(2), Seconds(1)) + .window(Seconds(4), Seconds(2)) + } + + val ssc = setupStreams(input, operation _) + ssc.remember(rememberDuration) + runStreams[(Int, Int)](ssc, input.size, input.size / 2) + + val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head + val windowedStream1 = windowedStream2.dependencies.head + val mappedStream = windowedStream1.dependencies.head + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + assert(clock.time === Seconds(10).milliseconds) + + // IDEALLY + // WindowedStream2 should remember till 7 seconds: 10, 8, + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5 + // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, + + // IN THIS TEST + // WindowedStream2 should remember till 7 seconds: 10, 8, + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 + // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 + + // WindowedStream2 + assert(windowedStream2.generatedRDDs.contains(Time(10000))) + assert(windowedStream2.generatedRDDs.contains(Time(8000))) + assert(!windowedStream2.generatedRDDs.contains(Time(6000))) + + // WindowedStream1 + assert(windowedStream1.generatedRDDs.contains(Time(10000))) + assert(windowedStream1.generatedRDDs.contains(Time(4000))) + assert(!windowedStream1.generatedRDDs.contains(Time(3000))) + + // MappedStream + assert(mappedStream.generatedRDDs.contains(Time(10000))) + assert(mappedStream.generatedRDDs.contains(Time(2000))) + assert(!mappedStream.generatedRDDs.contains(Time(1000))) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala new file mode 100644 index 0000000000..a327de80b3 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import dstream.FileInputDStream +import org.apache.spark.streaming.StreamingContext._ +import java.io.File +import runtime.RichInt +import org.scalatest.BeforeAndAfter +import org.apache.commons.io.FileUtils +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.{Clock, ManualClock} +import scala.util.Random +import com.google.common.io.Files + + +/** + * This test suites tests the checkpointing functionality of DStreams - + * the checkpointing of a DStream's RDDs as well as the checkpointing of + * the whole DStream graph. + */ +class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + if (ssc != null) ssc.stop() + FileUtils.deleteDirectory(new File(checkpointDir)) + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + var ssc: StreamingContext = null + + override def framework = "CheckpointSuite" + + override def batchDuration = Milliseconds(500) + + override def actuallyWait = true + + test("basic rdd checkpoints + dstream graph checkpoint recovery") { + + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + + val stateStreamCheckpointInterval = Seconds(1) + + // this ensure checkpointing occurs at least once + val firstNumBatches = (stateStreamCheckpointInterval / batchDuration).toLong * 2 + val secondNumBatches = firstNumBatches + + // Setup the streams + val input = (1 to 10).map(_ => Seq("a")).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(stateStreamCheckpointInterval) + .map(t => (t._1, t._2.self)) + } + var ssc = setupStreams(input, operation) + var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + + // Run till a time such that at least one RDD in the stream should have been checkpointed, + // then check whether some RDD has been checkpointed or not + ssc.start() + advanceTimeWithRealDelay(ssc, firstNumBatches) + logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) + assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") + stateStream.checkpointData.checkpointFiles.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") + } + } + + // Run till a further time such that previous checkpoint files in the stream would be deleted + // and check whether the earlier checkpoint files are deleted + val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) + advanceTimeWithRealDelay(ssc, secondNumBatches) + checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) + ssc.stop() + + // Restart stream computation using the checkpoint file and check whether + // checkpointed RDDs have been restored or not + 
ssc = new StreamingContext(checkpointDir) + stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]") + assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure") + + + // Run one batch to generate a new checkpoint file and check whether some RDD + // is present in the checkpoint data or not + ssc.start() + advanceTimeWithRealDelay(ssc, 1) + assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.checkpointFiles.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), + "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") + } + } + ssc.stop() + + // Restart stream computation from the new checkpoint file to see whether that file has + // correct checkpoint data + ssc = new StreamingContext(checkpointDir) + stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]") + assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure") + + // Adjust manual clock time as if it is being restarted after a delay + System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString) + ssc.start() + advanceTimeWithRealDelay(ssc, 4) + ssc.stop() + System.clearProperty("spark.streaming.manualClock.jump") + ssc = null + } + + // This tests whether the systm can recover from a master failure with simple + // non-stateful operations. This assumes as reliable, replayable input + // source - TestInputDStream. + test("recovery with map and reduceByKey operations") { + testCheckpointedOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + 3 + ) + } + + + // This tests whether the ReduceWindowedDStream's RDD checkpoints works correctly such + // that the system can recover from a master failure. This assumes as reliable, + // replayable input source - TestInputDStream. + test("recovery with invertible reduceByKeyAndWindow operation") { + val n = 10 + val w = 4 + val input = (1 to n).map(_ => Seq("a")).toSeq + val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)) + .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) + .checkpoint(batchDuration * 2) + } + testCheckpointedOperation(input, operation, output, 7) + } + + + // This tests whether the StateDStream's RDD checkpoints works correctly such + // that the system can recover from a master failure. This assumes as reliable, + // replayable input source - TestInputDStream. 
+ test("recovery with updateStateByKey operation") { + val input = (1 to 10).map(_ => Seq("a")).toSeq + val output = (1 to 10).map(x => Seq(("a", x))).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(batchDuration * 2) + .map(t => (t._1, t._2.self)) + } + testCheckpointedOperation(input, operation, output, 7) + } + + // This tests whether file input stream remembers what files were seen before + // the master failure and uses them again to process a large window operation. + // It also tests whether batches, whose processing was incomplete due to the + // failure, are re-processed or not. + test("recovery with file input stream") { + // Disable manual clock as FileInputDStream does not work with manual clock + val clockProperty = System.getProperty("spark.streaming.clock") + System.clearProperty("spark.streaming.clock") + + // Set up the streaming context and input streams + val testDir = Files.createTempDir() + var ssc = new StreamingContext(master, framework, Seconds(1)) + ssc.checkpoint(checkpointDir) + val fileStream = ssc.textFileStream(testDir.toString) + // Making value 3 take large time to process, to ensure that the master + // shuts down in the middle of processing the 3rd batch + val mappedStream = fileStream.map(s => { + val i = s.toInt + if (i == 3) Thread.sleep(2000) + i + }) + + // Reducing over a large window to ensure that recovery from master failure + // requires reprocessing of all the files seen before the failure + val reducedStream = mappedStream.reduceByWindow(_ + _, Seconds(30), Seconds(1)) + val outputBuffer = new ArrayBuffer[Seq[Int]] + var outputStream = new TestOutputStream(reducedStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Create files and advance manual clock to process them + //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + Thread.sleep(1000) + for (i <- Seq(1, 2, 3)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + // wait to make sure that the file is written such that it gets shown in the file listings + Thread.sleep(1000) + } + logInfo("Output = " + outputStream.output.mkString(",")) + assert(outputStream.output.size > 0, "No files processed before restart") + ssc.stop() + + // Verify whether files created have been recorded correctly or not + var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + def recordedFiles = fileInputDStream.files.values.flatMap(x => x) + assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + + // Create files while the master is down + for (i <- Seq(4, 5, 6)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(1000) + } + + // Recover context from checkpoint file and verify whether the files that were + // recorded before failure were saved and successfully recovered + logInfo("*********** RESTARTING ************") + ssc = new StreamingContext(checkpointDir) + fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) + 
assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + + // Restart stream computation + ssc.start() + for (i <- Seq(7, 8, 9)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(1000) + } + Thread.sleep(1000) + logInfo("Output = " + outputStream.output.mkString("[", ", ", "]")) + assert(outputStream.output.size > 0, "No files processed after restart") + ssc.stop() + + // Verify whether files created while the driver was down have been recorded or not + assert(!recordedFiles.filter(_.endsWith("4")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("5")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("6")).isEmpty) + + // Verify whether new files created after recover have been recorded or not + assert(!recordedFiles.filter(_.endsWith("7")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("8")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("9")).isEmpty) + + // Append the new output to the old buffer + outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] + outputBuffer ++= outputStream.output + + val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) + logInfo("--------------------------------") + logInfo("output, size = " + outputBuffer.size) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output, size = " + expectedOutput.size) + expectedOutput.foreach(x => logInfo("[" + x + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + val output = outputBuffer.flatMap(x => x) + assert(output.contains(6)) // To ensure that the 3rd input (i.e., 3) was processed + output.foreach(o => // To ensure all the inputs are correctly added cumulatively + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + // To ensure that all the inputs were received correctly + assert(expectedOutput.last === output.last) + + // Enable manual clock back again for other tests + if (clockProperty != null) + System.setProperty("spark.streaming.clock", clockProperty) + } + + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. 
+ */ + def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + initialNumBatches: Int + ) { + + // Current code assumes that: + // number of inputs = number of outputs = number of batches to be run + val totalNumBatches = input.size + val nextNumBatches = totalNumBatches - initialNumBatches + val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 + // because the last batch will be processed again + + // Do the computation for initial number of batches, create checkpoint file and quit + ssc = setupStreams[U, V](input, operation) + ssc.start() + val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) + ssc.stop() + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) + Thread.sleep(1000) + + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + ssc = new StreamingContext(checkpointDir) + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + ssc.start() + val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) + // the first element will be re-processed data of the last batch before restart + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + ssc.stop() + ssc = null + } + + /** + * Advances the manual clock on the streaming scheduler by given number of batches. + * It also waits for the expected amount of time for each batch. + */ + def advanceTimeWithRealDelay[V: ClassManifest](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to numBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + outputStream.output + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala new file mode 100644 index 0000000000..6337c5359c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.Logging +import org.apache.spark.streaming.util.MasterFailureTest +import StreamingContext._ + +import org.scalatest.{FunSuite, BeforeAndAfter} +import com.google.common.io.Files +import java.io.File +import org.apache.commons.io.FileUtils +import collection.mutable.ArrayBuffer + + +/** + * This testsuite tests master failures at random times while the stream is running using + * the real clock. + */ +class FailureSuite extends FunSuite with BeforeAndAfter with Logging { + + var directory = "FailureSuite" + val numBatches = 30 + val batchDuration = Milliseconds(1000) + + before { + FileUtils.deleteDirectory(new File(directory)) + } + + after { + FileUtils.deleteDirectory(new File(directory)) + } + + test("multiple failures with map") { + MasterFailureTest.testMap(directory, numBatches, batchDuration) + } + + test("multiple failures with updateStateByKey") { + MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala new file mode 100644 index 0000000000..42e3e51e3f --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import akka.actor.Actor +import akka.actor.IO +import akka.actor.IOManager +import akka.actor.Props +import akka.util.ByteString + +import dstream.SparkFlumeEvent +import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} +import java.io.{File, BufferedWriter, OutputStreamWriter} +import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.ManualClock +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receivers.Receiver +import org.apache.spark.Logging +import scala.util.Random +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfter +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.{specific, NettyTransceiver} +import org.apache.avro.ipc.specific.SpecificRequestor +import java.nio.ByteBuffer +import collection.JavaConversions._ +import java.nio.charset.Charset +import com.google.common.io.Files + +class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { + + val testPort = 9999 + + override def checkpointDir = "checkpoint" + + before { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + + test("socket input stream") { + // Start the server + val testServer = new TestServer() + testServer.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(master, framework, batchDuration) + val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5) + val expectedOutput = input.map(_.toString) + Thread.sleep(1000) + for (i <- 0 until input.size) { + testServer.send(input(i).toString + "\n") + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } + + + test("flume input stream") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(master, framework, batchDuration) + val flumeStream = 
ssc.flumeStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5) + Thread.sleep(1000) + val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)); + val client = SpecificRequestor.getClient( + classOf[AvroSourceProtocol], transceiver); + + for (i <- 0 until input.size) { + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(input(i).toString.getBytes())) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + client.append(event) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + + val startTime = System.currentTimeMillis() + while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) + Thread.sleep(100) + } + Thread.sleep(1000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + val decoder = Charset.forName("UTF-8").newDecoder() + + assert(outputBuffer.size === input.length) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + val str = decoder.decode(outputBuffer(i).head.event.getBody) + assert(str.toString === input(i).toString) + assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + } + } + + + test("file input stream") { + // Disable manual clock as FileInputDStream does not work with manual clock + System.clearProperty("spark.streaming.clock") + + // Set up the streaming context and input streams + val testDir = Files.createTempDir() + val ssc = new StreamingContext(master, framework, batchDuration) + val fileStream = ssc.textFileStream(testDir.toString) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + def output = outputBuffer.flatMap(x => x) + val outputStream = new TestOutputStream(fileStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Create files in the temporary directory so that Spark Streaming can read data from it + val input = Seq(1, 2, 3, 4, 5) + val expectedOutput = input.map(_.toString) + Thread.sleep(1000) + for (i <- 0 until input.size) { + val file = new File(testDir, i.toString) + FileUtils.writeStringToFile(file, input(i).toString + "\n") + logInfo("Created file " + file) + Thread.sleep(batchDuration.milliseconds) + Thread.sleep(1000) + } + val startTime = System.currentTimeMillis() + Thread.sleep(1000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received by Spark Streaming was as expected + logInfo("--------------------------------") + logInfo("output, size = " + outputBuffer.size) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output, size = " + expectedOutput.size) + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were 
received one in each interval is not verified) + assert(output.toList === expectedOutput.toList) + + FileUtils.deleteDirectory(testDir) + + // Enable manual clock back again for other tests + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + + + test("actor input stream") { + // Start the server + val testServer = new TestServer() + val port = testServer.port + testServer.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(master, framework, batchDuration) + val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", + StorageLevel.MEMORY_AND_DISK) //Had to pass the local value of port to prevent from closing over entire scope + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = 1 to 9 + val expectedOutput = input.map(x => x.toString) + Thread.sleep(1000) + for (i <- 0 until input.size) { + testServer.send(input(i).toString) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } + + test("kafka input stream") { + val ssc = new StreamingContext(master, framework, batchDuration) + val topics = Map("my-topic" -> 1) + val test1 = ssc.kafkaStream("localhost:12345", "group", topics) + val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK) + + // Test specifying decoder + val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group") + val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) + } +} + + +/** This is server to test the network input stream */ +class TestServer() extends Logging { + + val queue = new ArrayBlockingQueue[String](100) + + val serverSocket = new ServerSocket(0) + + val servingThread = new Thread() { + override def run() { + try { + while(true) { + logInfo("Accepting connections on port " + port) + val clientSocket = serverSocket.accept() + logInfo("New connection") + try { + clientSocket.setTcpNoDelay(true) + val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + + while(clientSocket.isConnected) { + val msg = queue.poll(100, TimeUnit.MILLISECONDS) + if (msg != null) { + outputStream.write(msg) + outputStream.flush() + logInfo("Message '" + msg + "' sent") + } + } + } catch { + case e: SocketException => logError("TestServer error", e) + } 
finally { + logInfo("Connection closed") + if (!clientSocket.isClosed) clientSocket.close() + } + } + } catch { + case ie: InterruptedException => + + } finally { + serverSocket.close() + } + } + } + + def start() { servingThread.start() } + + def send(msg: String) { queue.add(msg) } + + def stop() { servingThread.interrupt() } + + def port = serverSocket.getLocalPort +} + +object TestServer { + def main(args: Array[String]) { + val s = new TestServer() + s.start() + while(true) { + Thread.sleep(1000) + s.send("hello") + } + } +} + +class TestActor(port: Int) extends Actor with Receiver { + + def bytesToString(byteString: ByteString) = byteString.utf8String + + override def preStart = IOManager(context.system).connect(new InetSocketAddress(port)) + + def receive = { + case IO.Read(socket, bytes) => + pushBlock(bytesToString(bytes)) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala new file mode 100644 index 0000000000..31c2fa0208 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.dstream.{InputDStream, ForEachDStream} +import org.apache.spark.streaming.util.ManualClock + +import org.apache.spark.{RDD, Logging} + +import collection.mutable.ArrayBuffer +import collection.mutable.SynchronizedBuffer + +import java.io.{ObjectInputStream, IOException} + +import org.scalatest.{BeforeAndAfter, FunSuite} + +/** + * This is a input stream just for the testsuites. This is equivalent to a checkpointable, + * replayable, reliable message queue like Kafka. It requires a sequence as input, and + * returns the i_th element at the i_th batch unde manual clock. + */ +class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) + extends InputDStream[T](ssc_) { + + def start() {} + + def stop() {} + + def compute(validTime: Time): Option[RDD[T]] = { + logInfo("Computing RDD for time " + validTime) + val index = ((validTime - zeroTime) / slideDuration - 1).toInt + val selectedInput = if (index < input.size) input(index) else Seq[T]() + + // lets us test cases where RDDs are not created + if (selectedInput == null) + return None + + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) + logInfo("Created RDD " + rdd.id + " with " + selectedInput) + Some(rdd) + } +} + +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. 
+ */ +class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} + +/** + * This is the base trait for Spark Streaming testsuites. This provides basic functionality + * to run user-defined set of input on user-defined stream operations, and verify the output. + */ +trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { + + // Name of the framework for Spark context + def framework = "TestSuiteBase" + + // Master for Spark context + def master = "local[2]" + + // Batch duration + def batchDuration = Seconds(1) + + // Directory where the checkpoint data will be saved + def checkpointDir = "checkpoint" + + // Number of partitions of the input parallel collections created for testing + def numInputPartitions = 2 + + // Maximum time to wait before the test times out + def maxWaitTimeMillis = 10000 + + // Whether to actually wait in real time before changing manual clock + def actuallyWait = false + + /** + * Set up required DStreams to test the DStream operation using the two sequences + * of input collections. + */ + def setupStreams[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V] + ): StreamingContext = { + + // Create StreamingContext + val ssc = new StreamingContext(master, framework, batchDuration) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir) + } + + // Setup the stream computation + val inputStream = new TestInputStream(ssc, input, numInputPartitions) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) + ssc.registerInputStream(inputStream) + ssc.registerOutputStream(outputStream) + ssc + } + + /** + * Set up required DStreams to test the binary operation using the sequence + * of input collections. + */ + def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W] + ): StreamingContext = { + + // Create StreamingContext + val ssc = new StreamingContext(master, framework, batchDuration) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir) + } + + // Setup the stream computation + val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions) + val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) + val operatedStream = operation(inputStream1, inputStream2) + val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]]) + ssc.registerInputStream(inputStream1) + ssc.registerInputStream(inputStream2) + ssc.registerOutputStream(outputStream) + ssc + } + + /** + * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and + * returns the collected output. It will wait until `numExpectedOutput` number of + * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. 
+ */ + def runStreams[V: ClassManifest]( + ssc: StreamingContext, + numBatches: Int, + numExpectedOutput: Int + ): Seq[Seq[V]] = { + assert(numBatches > 0, "Number of batches to run stream computation is zero") + assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") + logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) + + // Get the output buffer + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val output = outputStream.output + + try { + // Start computation + ssc.start() + + // Advance manual clock + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + if (actuallyWait) { + for (i <- 1 to numBatches) { + logInfo("Actually waiting for " + batchDuration) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + } else { + clock.addToTime(numBatches * batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + + // Wait until expected number of output items have been generated + val startTime = System.currentTimeMillis() + while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) + Thread.sleep(100) + } + val timeTaken = System.currentTimeMillis() - startTime + + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") + + Thread.sleep(500) // Give some time for the forgetting old RDDs to complete + } catch { + case e: Exception => e.printStackTrace(); throw e; + } finally { + ssc.stop() + } + output + } + + /** + * Verify whether the output values after running a DStream operation + * is same as the expected output values, by comparing the output + * collections either as lists (order matters) or sets (order does not matter) + */ + def verifyOutput[V: ClassManifest]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + useSet: Boolean + ) { + logInfo("--------------------------------") + logInfo("output.size = " + output.size) + logInfo("output") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Match the output with the expected output + assert(output.size === expectedOutput.size, "Number of outputs do not match") + for (i <- 0 until output.size) { + if (useSet) { + assert(output(i).toSet === expectedOutput(i).toSet) + } else { + assert(output(i).toList === expectedOutput(i).toList) + } + } + logInfo("Output verified successfully") + } + + /** + * Test unary DStream operation with a list of inputs, with number of + * batches to run same as the number of expected output values + */ + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { + testOperation[U, V](input, operation, expectedOutput, -1, useSet) + } + + /** + * Test unary DStream operation with a list of inputs + * @param input Sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * 
@param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, numBatches_, expectedOutput.size) + verifyOutput[V](output, expectedOutput, useSet) + } + + /** + * Test binary DStream operation with two lists of inputs, with number of + * batches to run same as the number of expected output values + */ + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + useSet: Boolean + ) { + testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) + } + + /** + * Test binary DStream operation with two lists of inputs + * @param input1 First sequence of input collections + * @param input2 Second sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V, W](input1, input2, operation) + val output = runStreams[W](ssc, numBatches_, expectedOutput.size) + verifyOutput[W](output, expectedOutput, useSet) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala new file mode 100644 index 0000000000..f50e05c0d8 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.StreamingContext._ +import collection.mutable.ArrayBuffer + +class WindowOperationsSuite extends TestSuiteBase { + + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + + override def framework = "WindowOperationsSuite" + + override def maxWaitTimeMillis = 20000 + + override def batchDuration = Seconds(1) + + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + val largerSlideInput = Seq( + Seq(("a", 1)), + Seq(("a", 2)), // 1st window from here + Seq(("a", 3)), + Seq(("a", 4)), // 2nd window from here + Seq(("a", 5)), + Seq(("a", 6)), // 3rd window from here + Seq(), + Seq() // 4th window from here + ) + + val largerSlideReduceOutput = Seq( + Seq(("a", 3)), + Seq(("a", 10)), + Seq(("a", 18)), + Seq(("a", 11)) + ) + + + val bigInput = Seq( + Seq(("a", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1)), + Seq(), + Seq(("a", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1)), + Seq() + ) + + val bigGroupByOutput = Seq( + Seq(("a", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1))), + Seq(("a", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1))) + ) + + + val bigReduceOutput = Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 1)), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 1)) + ) + + /* + The output of the reduceByKeyAndWindow with inverse function but without a filter + function will be different from the naive reduceByKeyAndWindow, as no keys get + eliminated from the ReducedWindowedDStream even if the value of a key becomes 0. 
+ */ + + val bigReduceInvOutput = Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)) + ) + + // Testing window operation + + testWindow( + "basic window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(0), Seq(0, 1), Seq(1, 2), Seq(2, 3), Seq(3, 4), Seq(4, 5)) + ) + + testWindow( + "tumbling window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(0, 1), Seq(2, 3), Seq(4, 5)), + Seconds(2), + Seconds(2) + ) + + testWindow( + "larger window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(0, 1), Seq(0, 1, 2, 3), Seq(2, 3, 4, 5), Seq(4, 5)), + Seconds(4), + Seconds(2) + ) + + testWindow( + "non-overlapping window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(1, 2), Seq(4, 5)), + Seconds(2), + Seconds(3) + ) + + // Testing naive reduceByKeyAndWindow (without invertible function) + + testReduceByKeyAndWindow( + "basic reduction", + Seq( Seq(("a", 1), ("a", 3)) ), + Seq( Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindow( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + testReduceByKeyAndWindow( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindow( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() ) + ) + + testReduceByKeyAndWindow( + "larger slide time", + largerSlideInput, + largerSlideReduceOutput, + Seconds(4), + Seconds(2) + ) + + testReduceByKeyAndWindow("big test", bigInput, bigReduceOutput) + + // Testing reduceByKeyAndWindow (with invertible reduce function) + + testReduceByKeyAndWindowWithInverse( + "basic reduction", + Seq(Seq(("a", 1), ("a", 3)) ), + Seq(Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindowWithInverse( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + testReduceByKeyAndWindowWithInverse( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindowWithInverse( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) + ) + + testReduceByKeyAndWindowWithInverse( + "larger slide time", + largerSlideInput, + largerSlideReduceOutput, + Seconds(4), + Seconds(2) + ) + + testReduceByKeyAndWindowWithInverse("big test", bigInput, bigReduceInvOutput) + + testReduceByKeyAndWindowWithFilteredInverse("big test", bigInput, bigReduceOutput) + + test("groupByKeyAndWindow") { + val input = bigInput + val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) + val windowDuration = Seconds(2) + val slideDuration = Seconds(1) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[(String, Int)]) => { + s.groupByKeyAndWindow(windowDuration, slideDuration) + .map(x => (x._1, x._2.toSet)) + .persist() + } + 
testOperation(input, operation, expectedOutput, numBatches, true) + } + + test("countByWindow") { + val input = Seq(Seq(1), Seq(1), Seq(1, 2), Seq(0), Seq(), Seq() ) + val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) + val windowDuration = Seconds(2) + val slideDuration = Seconds(1) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[Int]) => { + s.countByWindow(windowDuration, slideDuration).map(_.toInt) + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + + test("countByValueAndWindow") { + val input = Seq(Seq("a"), Seq("b", "b"), Seq("a", "b")) + val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) + val windowDuration = Seconds(2) + val slideDuration = Seconds(1) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[String]) => { + s.countByValueAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + + + // Helper functions + + def testWindow( + name: String, + input: Seq[Seq[Int]], + expectedOutput: Seq[Seq[Int]], + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) + ) { + test("window - " + name) { + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[Int]) => s.window(windowDuration, slideDuration) + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + def testReduceByKeyAndWindow( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) + ) { + test("reduceByKeyAndWindow - " + name) { + logInfo("reduceByKeyAndWindow - " + name) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow((x: Int, y: Int) => x + y, windowDuration, slideDuration) + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + def testReduceByKeyAndWindowWithInverse( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) + ) { + test("reduceByKeyAndWindow with inverse function - " + name) { + logInfo("reduceByKeyAndWindow with inverse function - " + name) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration) + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + def testReduceByKeyAndWindowWithFilteredInverse( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) + ) { + test("reduceByKeyAndWindow with inverse and filter functions - " + name) { + logInfo("reduceByKeyAndWindow with inverse and filter functions - " + name) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val filterFunc = (p: (String, Int)) => p._2 != 0 + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration, filterFunc = filterFunc) + 
.persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } +} diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala deleted file mode 100644 index 67e3e0cd30..0000000000 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ /dev/null @@ -1,322 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import spark.streaming.StreamingContext._ -import scala.runtime.RichInt -import util.ManualClock - -class BasicOperationsSuite extends TestSuiteBase { - - override def framework() = "BasicOperationsSuite" - - before { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - } - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - test("map") { - val input = Seq(1 to 4, 5 to 8, 9 to 12) - testOperation( - input, - (r: DStream[Int]) => r.map(_.toString), - input.map(_.map(_.toString)) - ) - } - - test("flatMap") { - val input = Seq(1 to 4, 5 to 8, 9 to 12) - testOperation( - input, - (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), - input.map(_.flatMap(x => Array(x, x * 2))) - ) - } - - test("filter") { - val input = Seq(1 to 4, 5 to 8, 9 to 12) - testOperation( - input, - (r: DStream[Int]) => r.filter(x => (x % 2 == 0)), - input.map(_.filter(x => (x % 2 == 0))) - ) - } - - test("glom") { - assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") - val input = Seq(1 to 4, 5 to 8, 9 to 12) - val output = Seq( - Seq( Seq(1, 2), Seq(3, 4) ), - Seq( Seq(5, 6), Seq(7, 8) ), - Seq( Seq(9, 10), Seq(11, 12) ) - ) - val operation = (r: DStream[Int]) => r.glom().map(_.toSeq) - testOperation(input, operation, output) - } - - test("mapPartitions") { - assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") - val input = Seq(1 to 4, 5 to 8, 9 to 12) - val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) - val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) - testOperation(input, operation, output, true) - } - - test("groupByKey") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), - Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), - true - ) - } - - test("reduceByKey") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 
2)), Seq() ), - true - ) - } - - test("reduce") { - testOperation( - Seq(1 to 4, 5 to 8, 9 to 12), - (s: DStream[Int]) => s.reduce(_ + _), - Seq(Seq(10), Seq(26), Seq(42)) - ) - } - - test("count") { - testOperation( - Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4), - (s: DStream[Int]) => s.count(), - Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L)) - ) - } - - test("countByValue") { - testOperation( - Seq(1 to 1, Seq(1, 1, 1), 1 to 2, Seq(1, 1, 2, 2)), - (s: DStream[Int]) => s.countByValue(), - Seq(Seq((1, 1L)), Seq((1, 3L)), Seq((1, 1L), (2, 1L)), Seq((2, 2L), (1, 2L))), - true - ) - } - - test("mapValues") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10), - Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ), - true - ) - } - - test("flatMapValues") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), - Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), - true - ) - } - - test("cogroup") { - val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) - val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) - val outputData = Seq( - Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ), - Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ), - Seq( ("", (Seq(1), Seq())) ), - Seq( ) - ) - val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) - } - testOperation(inputData1, inputData2, operation, outputData, true) - } - - test("join") { - val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) - val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) - val outputData = Seq( - Seq( ("a", (1, "x")), ("b", (1, "x")) ), - Seq( ("", (1, "x")) ), - Seq( ), - Seq( ) - ) - val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).join(s2.map(x => (x,"x"))) - } - testOperation(inputData1, inputData2, operation, outputData, true) - } - - test("updateStateByKey") { - val inputData = - Seq( - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = - Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ) - - val updateStateOperation = (s: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) - } - s.map(x => (x, 1)).updateStateByKey[Int](updateFunc) - } - - testOperation(inputData, updateStateOperation, outputData, true) - } - - test("updateStateByKey - object lifecycle") { - val inputData = - Seq( - Seq("a","b"), - null, - Seq("a","c","a"), - Seq("c"), - null, - null - ) - - val outputData = - Seq( - Seq(("a", 1), ("b", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 3), ("c", 1)), - Seq(("a", 3), ("c", 2)), - Seq(("c", 2)), - Seq() - ) - - val updateStateOperation = (s: DStream[String]) => { - class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable - - // updateFunc clears a state when a StateObject is seen without new values twice in a row - val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { - val stateObj = state.getOrElse(new StateObject) - 
values.foldLeft(0)(_ + _) match { - case 0 => stateObj.expireCounter += 1 // no new values - case n => { // has new values, increment and reset expireCounter - stateObj.counter += n - stateObj.expireCounter = 0 - } - } - stateObj.expireCounter match { - case 2 => None // seen twice with no new values, give it the boot - case _ => Option(stateObj) - } - } - s.map(x => (x, 1)).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) - } - - testOperation(inputData, updateStateOperation, outputData, true) - } - - test("slice") { - val ssc = new StreamingContext("local[2]", "BasicOperationSuite", Seconds(1)) - val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) - val stream = new TestInputStream[Int](ssc, input, 2) - ssc.registerInputStream(stream) - stream.foreach(_ => {}) // Dummy output stream - ssc.start() - Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { - stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet - } - - assert(getInputFromSlice(0, 1000) == Set(1)) - assert(getInputFromSlice(0, 2000) == Set(1, 2)) - assert(getInputFromSlice(1000, 2000) == Set(1, 2)) - assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) - ssc.stop() - Thread.sleep(1000) - } - - test("forgetting of RDDs - map and window operations") { - assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") - - val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq - val rememberDuration = Seconds(3) - - assert(input.size === 10, "Number of inputs have changed") - - def operation(s: DStream[Int]): DStream[(Int, Int)] = { - s.map(x => (x % 10, 1)) - .window(Seconds(2), Seconds(1)) - .window(Seconds(4), Seconds(2)) - } - - val ssc = setupStreams(input, operation _) - ssc.remember(rememberDuration) - runStreams[(Int, Int)](ssc, input.size, input.size / 2) - - val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head - val windowedStream1 = windowedStream2.dependencies.head - val mappedStream = windowedStream1.dependencies.head - - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.time === Seconds(10).milliseconds) - - // IDEALLY - // WindowedStream2 should remember till 7 seconds: 10, 8, - // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5 - // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, - - // IN THIS TEST - // WindowedStream2 should remember till 7 seconds: 10, 8, - // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 - // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 - - // WindowedStream2 - assert(windowedStream2.generatedRDDs.contains(Time(10000))) - assert(windowedStream2.generatedRDDs.contains(Time(8000))) - assert(!windowedStream2.generatedRDDs.contains(Time(6000))) - - // WindowedStream1 - assert(windowedStream1.generatedRDDs.contains(Time(10000))) - assert(windowedStream1.generatedRDDs.contains(Time(4000))) - assert(!windowedStream1.generatedRDDs.contains(Time(3000))) - - // MappedStream - assert(mappedStream.generatedRDDs.contains(Time(10000))) - assert(mappedStream.generatedRDDs.contains(Time(2000))) - assert(!mappedStream.generatedRDDs.contains(Time(1000))) - } -} diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala deleted file mode 100644 index 8c639648f0..0000000000 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software 
Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import dstream.FileInputDStream -import spark.streaming.StreamingContext._ -import java.io.File -import runtime.RichInt -import org.scalatest.BeforeAndAfter -import org.apache.commons.io.FileUtils -import collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import util.{Clock, ManualClock} -import scala.util.Random -import com.google.common.io.Files - - -/** - * This test suites tests the checkpointing functionality of DStreams - - * the checkpointing of a DStream's RDDs as well as the checkpointing of - * the whole DStream graph. - */ -class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - - before { - FileUtils.deleteDirectory(new File(checkpointDir)) - } - - after { - if (ssc != null) ssc.stop() - FileUtils.deleteDirectory(new File(checkpointDir)) - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - var ssc: StreamingContext = null - - override def framework = "CheckpointSuite" - - override def batchDuration = Milliseconds(500) - - override def actuallyWait = true - - test("basic rdd checkpoints + dstream graph checkpoint recovery") { - - assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - - val stateStreamCheckpointInterval = Seconds(1) - - // this ensure checkpointing occurs at least once - val firstNumBatches = (stateStreamCheckpointInterval / batchDuration).toLong * 2 - val secondNumBatches = firstNumBatches - - // Setup the streams - val input = (1 to 10).map(_ => Seq("a")).toSeq - val operation = (st: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { - Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) - } - st.map(x => (x, 1)) - .updateStateByKey[RichInt](updateFunc) - .checkpoint(stateStreamCheckpointInterval) - .map(t => (t._1, t._2.self)) - } - var ssc = setupStreams(input, operation) - var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head - - // Run till a time such that at least one RDD in the stream should have been checkpointed, - // then check whether some RDD has been checkpointed or not - ssc.start() - advanceTimeWithRealDelay(ssc, firstNumBatches) - logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.checkpointFiles.foreach { - case (time, data) => { - val 
file = new File(data.toString) - assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") - } - } - - // Run till a further time such that previous checkpoint files in the stream would be deleted - // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) - advanceTimeWithRealDelay(ssc, secondNumBatches) - checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) - ssc.stop() - - // Restart stream computation using the checkpoint file and check whether - // checkpointed RDDs have been restored or not - ssc = new StreamingContext(checkpointDir) - stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head - logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]") - assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure") - - - // Run one batch to generate a new checkpoint file and check whether some RDD - // is present in the checkpoint data or not - ssc.start() - advanceTimeWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.checkpointFiles.foreach { - case (time, data) => { - val file = new File(data.toString) - assert(file.exists(), - "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") - } - } - ssc.stop() - - // Restart stream computation from the new checkpoint file to see whether that file has - // correct checkpoint data - ssc = new StreamingContext(checkpointDir) - stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head - logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]") - assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure") - - // Adjust manual clock time as if it is being restarted after a delay - System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString) - ssc.start() - advanceTimeWithRealDelay(ssc, 4) - ssc.stop() - System.clearProperty("spark.streaming.manualClock.jump") - ssc = null - } - - // This tests whether the systm can recover from a master failure with simple - // non-stateful operations. This assumes as reliable, replayable input - // source - TestInputDStream. - test("recovery with map and reduceByKey operations") { - testCheckpointedOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - 3 - ) - } - - - // This tests whether the ReduceWindowedDStream's RDD checkpoints works correctly such - // that the system can recover from a master failure. This assumes as reliable, - // replayable input source - TestInputDStream. 
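// Editor's sketch (not part of the original patch): the arithmetic behind the
// expected output that the test below builds by hand. With one ("a", 1) record
// per batch and a window of w batches sliding by one batch, the windowed sum at
// batch i is min(i, w). The object and method names here are illustrative only.
object WindowedCountSketch {
  def expected(n: Int, w: Int): Seq[Seq[(String, Int)]] =
    (1 to n).map(i => Seq(("a", math.min(i, w))))

  def main(args: Array[String]) {
    // For n = 10 and w = 4 this reproduces the sequence hard-coded in the test:
    // the counts ramp up 1, 2, 3 and then stay at 4 once the window is full.
    println(expected(10, 4))
  }
}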
- test("recovery with invertible reduceByKeyAndWindow operation") { - val n = 10 - val w = 4 - val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) - val operation = (st: DStream[String]) => { - st.map(x => (x, 1)) - .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) - .checkpoint(batchDuration * 2) - } - testCheckpointedOperation(input, operation, output, 7) - } - - - // This tests whether the StateDStream's RDD checkpoints works correctly such - // that the system can recover from a master failure. This assumes as reliable, - // replayable input source - TestInputDStream. - test("recovery with updateStateByKey operation") { - val input = (1 to 10).map(_ => Seq("a")).toSeq - val output = (1 to 10).map(x => Seq(("a", x))).toSeq - val operation = (st: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { - Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) - } - st.map(x => (x, 1)) - .updateStateByKey[RichInt](updateFunc) - .checkpoint(batchDuration * 2) - .map(t => (t._1, t._2.self)) - } - testCheckpointedOperation(input, operation, output, 7) - } - - // This tests whether file input stream remembers what files were seen before - // the master failure and uses them again to process a large window operation. - // It also tests whether batches, whose processing was incomplete due to the - // failure, are re-processed or not. - test("recovery with file input stream") { - // Disable manual clock as FileInputDStream does not work with manual clock - val clockProperty = System.getProperty("spark.streaming.clock") - System.clearProperty("spark.streaming.clock") - - // Set up the streaming context and input streams - val testDir = Files.createTempDir() - var ssc = new StreamingContext(master, framework, Seconds(1)) - ssc.checkpoint(checkpointDir) - val fileStream = ssc.textFileStream(testDir.toString) - // Making value 3 take large time to process, to ensure that the master - // shuts down in the middle of processing the 3rd batch - val mappedStream = fileStream.map(s => { - val i = s.toInt - if (i == 3) Thread.sleep(2000) - i - }) - - // Reducing over a large window to ensure that recovery from master failure - // requires reprocessing of all the files seen before the failure - val reducedStream = mappedStream.reduceByWindow(_ + _, Seconds(30), Seconds(1)) - val outputBuffer = new ArrayBuffer[Seq[Int]] - var outputStream = new TestOutputStream(reducedStream, outputBuffer) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Create files and advance manual clock to process them - //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - Thread.sleep(1000) - for (i <- Seq(1, 2, 3)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - // wait to make sure that the file is written such that it gets shown in the file listings - Thread.sleep(1000) - } - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0, "No files processed before restart") - ssc.stop() - - // Verify whether files created have been recorded correctly or not - var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - def recordedFiles = fileInputDStream.files.values.flatMap(x => x) - assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) - 
assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) - - // Create files while the master is down - for (i <- Seq(4, 5, 6)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(1000) - } - - // Recover context from checkpoint file and verify whether the files that were - // recorded before failure were saved and successfully recovered - logInfo("*********** RESTARTING ************") - ssc = new StreamingContext(checkpointDir) - fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) - - // Restart stream computation - ssc.start() - for (i <- Seq(7, 8, 9)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(1000) - } - Thread.sleep(1000) - logInfo("Output = " + outputStream.output.mkString("[", ", ", "]")) - assert(outputStream.output.size > 0, "No files processed after restart") - ssc.stop() - - // Verify whether files created while the driver was down have been recorded or not - assert(!recordedFiles.filter(_.endsWith("4")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("5")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("6")).isEmpty) - - // Verify whether new files created after recover have been recorded or not - assert(!recordedFiles.filter(_.endsWith("7")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("8")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("9")).isEmpty) - - // Append the new output to the old buffer - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] - outputBuffer ++= outputStream.output - - val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) - logInfo("--------------------------------") - logInfo("output, size = " + outputBuffer.size) - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output, size = " + expectedOutput.size) - expectedOutput.foreach(x => logInfo("[" + x + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - val output = outputBuffer.flatMap(x => x) - assert(output.contains(6)) // To ensure that the 3rd input (i.e., 3) was processed - output.foreach(o => // To ensure all the inputs are correctly added cumulatively - assert(expectedOutput.contains(o), "Expected value " + o + " not found") - ) - // To ensure that all the inputs were received correctly - assert(expectedOutput.last === output.last) - - // Enable manual clock back again for other tests - if (clockProperty != null) - System.setProperty("spark.streaming.clock", clockProperty) - } - - - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. 
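 *
 * Worked example (editor's note, using the map/reduceByKey recovery test above):
 * with input.size = 6 and initialNumBatches = 3, the restarted context runs
 * 6 - 3 = 3 more batches but is expected to produce 6 - 3 + 1 = 4 outputs,
 * because the third batch is processed again after recovery.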
- */ - def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) - - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null - } - - /** - * Advances the manual clock on the streaming scheduler by given number of batches. - * It also waits for the expected amount of time for each batch. - */ - def advanceTimeWithRealDelay[V: ClassManifest](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.time) - for (i <- 1 to numBatches.toInt) { - clock.addToTime(batchDuration.milliseconds) - Thread.sleep(batchDuration.milliseconds) - } - logInfo("Manual clock after advancing = " + clock.time) - Thread.sleep(batchDuration.milliseconds) - - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] - outputStream.output - } -} diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala deleted file mode 100644 index 7fc649fe27..0000000000 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming - -import spark.Logging -import spark.streaming.util.MasterFailureTest -import StreamingContext._ - -import org.scalatest.{FunSuite, BeforeAndAfter} -import com.google.common.io.Files -import java.io.File -import org.apache.commons.io.FileUtils -import collection.mutable.ArrayBuffer - - -/** - * This testsuite tests master failures at random times while the stream is running using - * the real clock. - */ -class FailureSuite extends FunSuite with BeforeAndAfter with Logging { - - var directory = "FailureSuite" - val numBatches = 30 - val batchDuration = Milliseconds(1000) - - before { - FileUtils.deleteDirectory(new File(directory)) - } - - after { - FileUtils.deleteDirectory(new File(directory)) - } - - test("multiple failures with map") { - MasterFailureTest.testMap(directory, numBatches, batchDuration) - } - - test("multiple failures with updateStateByKey") { - MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) - } -} - diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala deleted file mode 100644 index 1c5419b16d..0000000000 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming - -import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager -import akka.actor.Props -import akka.util.ByteString - -import dstream.SparkFlumeEvent -import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} -import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} -import collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import util.ManualClock -import spark.storage.StorageLevel -import spark.streaming.receivers.Receiver -import spark.Logging -import scala.util.Random -import org.apache.commons.io.FileUtils -import org.scalatest.BeforeAndAfter -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.{specific, NettyTransceiver} -import org.apache.avro.ipc.specific.SpecificRequestor -import java.nio.ByteBuffer -import collection.JavaConversions._ -import java.nio.charset.Charset -import com.google.common.io.Files - -class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { - - val testPort = 9999 - - override def checkpointDir = "checkpoint" - - before { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - } - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - - test("socket input stream") { - // Start the server - val testServer = new TestServer() - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(master, framework, batchDuration) - val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - val expectedOutput = input.map(_.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString + "\n") - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - - test("flume input stream") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(master, framework, batchDuration) - val flumeStream = ssc.flumeStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) - val 
outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - ssc.registerOutputStream(outputStream) - ssc.start() - - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - Thread.sleep(1000) - val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)); - val client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver); - - for (i <- 0 until input.size) { - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(input(i).toString.getBytes())) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - client.append(event) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - - val startTime = System.currentTimeMillis() - while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) - Thread.sleep(100) - } - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val decoder = Charset.forName("UTF-8").newDecoder() - - assert(outputBuffer.size === input.length) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - val str = decoder.decode(outputBuffer(i).head.event.getBody) - assert(str.toString === input(i).toString) - assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") - } - } - - - test("file input stream") { - // Disable manual clock as FileInputDStream does not work with manual clock - System.clearProperty("spark.streaming.clock") - - // Set up the streaming context and input streams - val testDir = Files.createTempDir() - val ssc = new StreamingContext(master, framework, batchDuration) - val fileStream = ssc.textFileStream(testDir.toString) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(fileStream, outputBuffer) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Create files in the temporary directory so that Spark Streaming can read data from it - val input = Seq(1, 2, 3, 4, 5) - val expectedOutput = input.map(_.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - val file = new File(testDir, i.toString) - FileUtils.writeStringToFile(file, input(i).toString + "\n") - logInfo("Created file " + file) - Thread.sleep(batchDuration.milliseconds) - Thread.sleep(1000) - } - val startTime = System.currentTimeMillis() - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received by Spark Streaming was as expected - logInfo("--------------------------------") - logInfo("output, size = " + outputBuffer.size) - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output, size = " + expectedOutput.size) - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.toList === 
expectedOutput.toList) - - FileUtils.deleteDirectory(testDir) - - // Enable manual clock back again for other tests - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - } - - - test("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(master, framework, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - StorageLevel.MEMORY_AND_DISK) //Had to pass the local value of port to prevent from closing over entire scope - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("kafka input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val topics = Map("my-topic" -> 1) - val test1 = ssc.kafkaStream("localhost:12345", "group", topics) - val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK) - - // Test specifying decoder - val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group") - val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) - } -} - - -/** This is server to test the network input stream */ -class TestServer() extends Logging { - - val queue = new ArrayBlockingQueue[String](100) - - val serverSocket = new ServerSocket(0) - - val servingThread = new Thread() { - override def run() { - try { - while(true) { - logInfo("Accepting connections on port " + port) - val clientSocket = serverSocket.accept() - logInfo("New connection") - try { - clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) - - while(clientSocket.isConnected) { - val msg = queue.poll(100, TimeUnit.MILLISECONDS) - if (msg != null) { - outputStream.write(msg) - outputStream.flush() - logInfo("Message '" + msg + "' sent") - } - } - } catch { - case e: SocketException => logError("TestServer error", e) - } finally { - logInfo("Connection closed") - if (!clientSocket.isClosed) clientSocket.close() 
- } - } - } catch { - case ie: InterruptedException => - - } finally { - serverSocket.close() - } - } - } - - def start() { servingThread.start() } - - def send(msg: String) { queue.add(msg) } - - def stop() { servingThread.interrupt() } - - def port = serverSocket.getLocalPort -} - -object TestServer { - def main(args: Array[String]) { - val s = new TestServer() - s.start() - while(true) { - Thread.sleep(1000) - s.send("hello") - } - } -} - -class TestActor(port: Int) extends Actor with Receiver { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart = IOManager(context.system).connect(new InetSocketAddress(port)) - - def receive = { - case IO.Read(socket, bytes) => - pushBlock(bytesToString(bytes)) - } -} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala deleted file mode 100644 index cb34b5a7cc..0000000000 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ /dev/null @@ -1,314 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming - -import spark.streaming.dstream.{InputDStream, ForEachDStream} -import spark.streaming.util.ManualClock - -import spark.{RDD, Logging} - -import collection.mutable.ArrayBuffer -import collection.mutable.SynchronizedBuffer - -import java.io.{ObjectInputStream, IOException} - -import org.scalatest.{BeforeAndAfter, FunSuite} - -/** - * This is a input stream just for the testsuites. This is equivalent to a checkpointable, - * replayable, reliable message queue like Kafka. It requires a sequence as input, and - * returns the i_th element at the i_th batch unde manual clock. - */ -class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) - extends InputDStream[T](ssc_) { - - def start() {} - - def stop() {} - - def compute(validTime: Time): Option[RDD[T]] = { - logInfo("Computing RDD for time " + validTime) - val index = ((validTime - zeroTime) / slideDuration - 1).toInt - val selectedInput = if (index < input.size) input(index) else Seq[T]() - - // lets us test cases where RDDs are not created - if (selectedInput == null) - return None - - val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) - logInfo("Created RDD " + rdd.id + " with " + selectedInput) - Some(rdd) - } -} - -/** - * This is a output stream just for the testsuites. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. 
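 *
 * Editor's sketch of how the two test streams are typically wired up by hand
 * (mirrors the "slice" and checkpoint recovery tests; the variable names and the
 * map(_ * 10) operation are illustrative only):
 * {{{
 *   val ssc = new StreamingContext("local[2]", "TestStreamWiring", Seconds(1))
 *   val inputStream = new TestInputStream[Int](ssc, Seq(Seq(1), Seq(2), Seq(3)), 2)
 *   val outputBuffer = new ArrayBuffer[Seq[Int]] with SynchronizedBuffer[Seq[Int]]
 *   val outputStream = new TestOutputStream(inputStream.map(_ * 10), outputBuffer)
 *   ssc.registerInputStream(inputStream)
 *   ssc.registerOutputStream(outputStream)
 *   // after ssc.start(), advancing the ManualClock fills outputBuffer
 * }}}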
- */ -class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) - extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output += collected - }) { - - // This is to clear the output buffer every it is read from a checkpoint - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - output.clear() - } -} - -/** - * This is the base trait for Spark Streaming testsuites. This provides basic functionality - * to run user-defined set of input on user-defined stream operations, and verify the output. - */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { - - // Name of the framework for Spark context - def framework = "TestSuiteBase" - - // Master for Spark context - def master = "local[2]" - - // Batch duration - def batchDuration = Seconds(1) - - // Directory where the checkpoint data will be saved - def checkpointDir = "checkpoint" - - // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 - - // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 - - // Whether to actually wait in real time before changing manual clock - def actuallyWait = false - - /** - * Set up required DStreams to test the DStream operation using the two sequences - * of input collections. - */ - def setupStreams[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V] - ): StreamingContext = { - - // Create StreamingContext - val ssc = new StreamingContext(master, framework, batchDuration) - if (checkpointDir != null) { - ssc.checkpoint(checkpointDir) - } - - // Setup the stream computation - val inputStream = new TestInputStream(ssc, input, numInputPartitions) - val operatedStream = operation(inputStream) - val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) - ssc.registerInputStream(inputStream) - ssc.registerOutputStream(outputStream) - ssc - } - - /** - * Set up required DStreams to test the binary operation using the sequence - * of input collections. - */ - def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( - input1: Seq[Seq[U]], - input2: Seq[Seq[V]], - operation: (DStream[U], DStream[V]) => DStream[W] - ): StreamingContext = { - - // Create StreamingContext - val ssc = new StreamingContext(master, framework, batchDuration) - if (checkpointDir != null) { - ssc.checkpoint(checkpointDir) - } - - // Setup the stream computation - val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions) - val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) - val operatedStream = operation(inputStream1, inputStream2) - val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]]) - ssc.registerInputStream(inputStream1) - ssc.registerInputStream(inputStream2) - ssc.registerOutputStream(outputStream) - ssc - } - - /** - * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and - * returns the collected output. It will wait until `numExpectedOutput` number of - * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. 
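 *
 * Editor's note: batch generation is driven by the ManualClock selected through
 * the "spark.streaming.clock" system property, so batches fire when the clock is
 * advanced rather than in real time:
 * {{{
 *   val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
 *   clock.addToTime(numBatches * batchDuration.milliseconds)
 * }}}
 * unless `actuallyWait` is true, in which case the clock is advanced one batch at
 * a time with a real sleep of one batch duration in between.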
- */ - def runStreams[V: ClassManifest]( - ssc: StreamingContext, - numBatches: Int, - numExpectedOutput: Int - ): Seq[Seq[V]] = { - assert(numBatches > 0, "Number of batches to run stream computation is zero") - assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") - logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) - - // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] - val output = outputStream.output - - try { - // Start computation - ssc.start() - - // Advance manual clock - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.time) - if (actuallyWait) { - for (i <- 1 to numBatches) { - logInfo("Actually waiting for " + batchDuration) - clock.addToTime(batchDuration.milliseconds) - Thread.sleep(batchDuration.milliseconds) - } - } else { - clock.addToTime(numBatches * batchDuration.milliseconds) - } - logInfo("Manual clock after advancing = " + clock.time) - - // Wait until expected number of output items have been generated - val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - Thread.sleep(100) - } - val timeTaken = System.currentTimeMillis() - startTime - - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") - - Thread.sleep(500) // Give some time for the forgetting old RDDs to complete - } catch { - case e: Exception => e.printStackTrace(); throw e; - } finally { - ssc.stop() - } - output - } - - /** - * Verify whether the output values after running a DStream operation - * is same as the expected output values, by comparing the output - * collections either as lists (order matters) or sets (order does not matter) - */ - def verifyOutput[V: ClassManifest]( - output: Seq[Seq[V]], - expectedOutput: Seq[Seq[V]], - useSet: Boolean - ) { - logInfo("--------------------------------") - logInfo("output.size = " + output.size) - logInfo("output") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Match the output with the expected output - assert(output.size === expectedOutput.size, "Number of outputs do not match") - for (i <- 0 until output.size) { - if (useSet) { - assert(output(i).toSet === expectedOutput(i).toSet) - } else { - assert(output(i).toList === expectedOutput(i).toList) - } - } - logInfo("Output verified successfully") - } - - /** - * Test unary DStream operation with a list of inputs, with number of - * batches to run same as the number of expected output values - */ - def testOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - useSet: Boolean = false - ) { - testOperation[U, V](input, operation, expectedOutput, -1, useSet) - } - - /** - * Test unary DStream operation with a list of inputs - * @param input Sequence of input collections - * @param operation Binary DStream operation to be applied to the 2 inputs - * @param expectedOutput Sequence of expected output collections - * 
@param numBatches Number of batches to run the operation for - * @param useSet Compare the output values with the expected output values - * as sets (order matters) or as lists (order does not matter) - */ - def testOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - numBatches: Int, - useSet: Boolean - ) { - val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, numBatches_, expectedOutput.size) - verifyOutput[V](output, expectedOutput, useSet) - } - - /** - * Test binary DStream operation with two lists of inputs, with number of - * batches to run same as the number of expected output values - */ - def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( - input1: Seq[Seq[U]], - input2: Seq[Seq[V]], - operation: (DStream[U], DStream[V]) => DStream[W], - expectedOutput: Seq[Seq[W]], - useSet: Boolean - ) { - testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) - } - - /** - * Test binary DStream operation with two lists of inputs - * @param input1 First sequence of input collections - * @param input2 Second sequence of input collections - * @param operation Binary DStream operation to be applied to the 2 inputs - * @param expectedOutput Sequence of expected output collections - * @param numBatches Number of batches to run the operation for - * @param useSet Compare the output values with the expected output values - * as sets (order matters) or as lists (order does not matter) - */ - def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( - input1: Seq[Seq[U]], - input2: Seq[Seq[V]], - operation: (DStream[U], DStream[V]) => DStream[W], - expectedOutput: Seq[Seq[W]], - numBatches: Int, - useSet: Boolean - ) { - val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V, W](input1, input2, operation) - val output = runStreams[W](ssc, numBatches_, expectedOutput.size) - verifyOutput[W](output, expectedOutput, useSet) - } -} diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala deleted file mode 100644 index 894b765fc6..0000000000 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ /dev/null @@ -1,340 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming - -import spark.streaming.StreamingContext._ -import collection.mutable.ArrayBuffer - -class WindowOperationsSuite extends TestSuiteBase { - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - - override def framework = "WindowOperationsSuite" - - override def maxWaitTimeMillis = 20000 - - override def batchDuration = Seconds(1) - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - val largerSlideInput = Seq( - Seq(("a", 1)), - Seq(("a", 2)), // 1st window from here - Seq(("a", 3)), - Seq(("a", 4)), // 2nd window from here - Seq(("a", 5)), - Seq(("a", 6)), // 3rd window from here - Seq(), - Seq() // 4th window from here - ) - - val largerSlideReduceOutput = Seq( - Seq(("a", 3)), - Seq(("a", 10)), - Seq(("a", 18)), - Seq(("a", 11)) - ) - - - val bigInput = Seq( - Seq(("a", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1), ("b", 1), ("c", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1)), - Seq(), - Seq(("a", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1), ("b", 1), ("c", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1)), - Seq() - ) - - val bigGroupByOutput = Seq( - Seq(("a", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1))), - Seq(("a", Seq(1))), - Seq(("a", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), - Seq(("a", Seq(1, 1)), ("b", Seq(1))), - Seq(("a", Seq(1))) - ) - - - val bigReduceOutput = Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 1)), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 1)) - ) - - /* - The output of the reduceByKeyAndWindow with inverse function but without a filter - function will be different from the naive reduceByKeyAndWindow, as no keys get - eliminated from the ReducedWindowedDStream even if the value of a key becomes 0. 
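
   Editor's note: the filter used by testReduceByKeyAndWindowWithFilteredInverse
   further down drops exactly those zero-valued entries, which is why its
   "big test" reuses bigReduceOutput rather than bigReduceInvOutput:

     val filterFunc = (p: (String, Int)) => p._2 != 0
     // bigReduceInvOutput.map(_.filter(filterFunc)) == bigReduceOutput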
- */ - - val bigReduceInvOutput = Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1), ("c", 0)), - Seq(("a", 1), ("b", 0), ("c", 0)), - Seq(("a", 1), ("b", 0), ("c", 0)), - Seq(("a", 2), ("b", 1), ("c", 0)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1), ("c", 0)), - Seq(("a", 1), ("b", 0), ("c", 0)) - ) - - // Testing window operation - - testWindow( - "basic window", - Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), - Seq( Seq(0), Seq(0, 1), Seq(1, 2), Seq(2, 3), Seq(3, 4), Seq(4, 5)) - ) - - testWindow( - "tumbling window", - Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), - Seq( Seq(0, 1), Seq(2, 3), Seq(4, 5)), - Seconds(2), - Seconds(2) - ) - - testWindow( - "larger window", - Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), - Seq( Seq(0, 1), Seq(0, 1, 2, 3), Seq(2, 3, 4, 5), Seq(4, 5)), - Seconds(4), - Seconds(2) - ) - - testWindow( - "non-overlapping window", - Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), - Seq( Seq(1, 2), Seq(4, 5)), - Seconds(2), - Seconds(3) - ) - - // Testing naive reduceByKeyAndWindow (without invertible function) - - testReduceByKeyAndWindow( - "basic reduction", - Seq( Seq(("a", 1), ("a", 3)) ), - Seq( Seq(("a", 4)) ) - ) - - testReduceByKeyAndWindow( - "key already in window and new value added into window", - Seq( Seq(("a", 1)), Seq(("a", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2)) ) - ) - - testReduceByKeyAndWindow( - "new key added into window", - Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) - ) - - testReduceByKeyAndWindow( - "key removed from window", - Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), - Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() ) - ) - - testReduceByKeyAndWindow( - "larger slide time", - largerSlideInput, - largerSlideReduceOutput, - Seconds(4), - Seconds(2) - ) - - testReduceByKeyAndWindow("big test", bigInput, bigReduceOutput) - - // Testing reduceByKeyAndWindow (with invertible reduce function) - - testReduceByKeyAndWindowWithInverse( - "basic reduction", - Seq(Seq(("a", 1), ("a", 3)) ), - Seq(Seq(("a", 4)) ) - ) - - testReduceByKeyAndWindowWithInverse( - "key already in window and new value added into window", - Seq( Seq(("a", 1)), Seq(("a", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2)) ) - ) - - testReduceByKeyAndWindowWithInverse( - "new key added into window", - Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) - ) - - testReduceByKeyAndWindowWithInverse( - "key removed from window", - Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), - Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) - ) - - testReduceByKeyAndWindowWithInverse( - "larger slide time", - largerSlideInput, - largerSlideReduceOutput, - Seconds(4), - Seconds(2) - ) - - testReduceByKeyAndWindowWithInverse("big test", bigInput, bigReduceInvOutput) - - testReduceByKeyAndWindowWithFilteredInverse("big test", bigInput, bigReduceOutput) - - test("groupByKeyAndWindow") { - val input = bigInput - val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) - val windowDuration = Seconds(2) - val slideDuration = Seconds(1) - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[(String, Int)]) => { - s.groupByKeyAndWindow(windowDuration, slideDuration) - .map(x => (x._1, x._2.toSet)) - .persist() - } - 
testOperation(input, operation, expectedOutput, numBatches, true) - } - - test("countByWindow") { - val input = Seq(Seq(1), Seq(1), Seq(1, 2), Seq(0), Seq(), Seq() ) - val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) - val windowDuration = Seconds(2) - val slideDuration = Seconds(1) - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[Int]) => { - s.countByWindow(windowDuration, slideDuration).map(_.toInt) - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - - test("countByValueAndWindow") { - val input = Seq(Seq("a"), Seq("b", "b"), Seq("a", "b")) - val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) - val windowDuration = Seconds(2) - val slideDuration = Seconds(1) - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[String]) => { - s.countByValueAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - - - // Helper functions - - def testWindow( - name: String, - input: Seq[Seq[Int]], - expectedOutput: Seq[Seq[Int]], - windowDuration: Duration = Seconds(2), - slideDuration: Duration = Seconds(1) - ) { - test("window - " + name) { - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[Int]) => s.window(windowDuration, slideDuration) - testOperation(input, operation, expectedOutput, numBatches, true) - } - } - - def testReduceByKeyAndWindow( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowDuration: Duration = Seconds(2), - slideDuration: Duration = Seconds(1) - ) { - test("reduceByKeyAndWindow - " + name) { - logInfo("reduceByKeyAndWindow - " + name) - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow((x: Int, y: Int) => x + y, windowDuration, slideDuration) - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } - - def testReduceByKeyAndWindowWithInverse( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowDuration: Duration = Seconds(2), - slideDuration: Duration = Seconds(1) - ) { - test("reduceByKeyAndWindow with inverse function - " + name) { - logInfo("reduceByKeyAndWindow with inverse function - " + name) - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration) - .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } - - def testReduceByKeyAndWindowWithFilteredInverse( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowDuration: Duration = Seconds(2), - slideDuration: Duration = Seconds(1) - ) { - test("reduceByKeyAndWindow with inverse and filter functions - " + name) { - logInfo("reduceByKeyAndWindow with inverse and filter functions - " + name) - val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val filterFunc = (p: (String, Int)) => p._2 != 0 - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration, filterFunc = filterFunc) - 
.persist() - .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } -} -- cgit v1.2.3 From 0a8cc309211c62f8824d76618705c817edcf2424 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 1 Sep 2013 00:32:28 -0700 Subject: Move some classes to more appropriate packages: * RDD, *RDDFunctions -> org.apache.spark.rdd * Utils, ClosureCleaner, SizeEstimator -> org.apache.spark.util * JavaSerializer, KryoSerializer -> org.apache.spark.serializer --- .../main/scala/org/apache/spark/bagel/Bagel.scala | 2 +- .../main/scala/org/apache/spark/Accumulators.scala | 1 + .../main/scala/org/apache/spark/CacheManager.scala | 1 + .../scala/org/apache/spark/ClosureCleaner.scala | 231 ----- .../main/scala/org/apache/spark/Dependency.scala | 2 + .../org/apache/spark/DoubleRDDFunctions.scala | 78 -- .../scala/org/apache/spark/HttpFileServer.scala | 1 + .../main/scala/org/apache/spark/HttpServer.scala | 1 + .../scala/org/apache/spark/JavaSerializer.scala | 83 -- .../scala/org/apache/spark/KryoSerializer.scala | 156 ---- .../scala/org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/PairRDDFunctions.scala | 703 --------------- .../main/scala/org/apache/spark/Partitioner.scala | 3 + core/src/main/scala/org/apache/spark/RDD.scala | 957 --------------------- .../scala/org/apache/spark/RDDCheckpointData.scala | 130 --- .../apache/spark/SequenceFileRDDFunctions.scala | 107 --- .../scala/org/apache/spark/SizeEstimator.scala | 283 ------ .../main/scala/org/apache/spark/SparkContext.scala | 8 +- .../src/main/scala/org/apache/spark/SparkEnv.scala | 6 +- core/src/main/scala/org/apache/spark/Utils.scala | 780 ----------------- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 2 +- .../org/apache/spark/api/java/JavaPairRDD.scala | 2 +- .../scala/org/apache/spark/api/java/JavaRDD.scala | 1 + .../org/apache/spark/api/java/JavaRDDLike.scala | 6 +- .../apache/spark/api/java/JavaSparkContext.scala | 6 +- .../spark/api/python/PythonPartitioner.scala | 2 +- .../org/apache/spark/api/python/PythonRDD.scala | 2 + .../spark/broadcast/BitTorrentBroadcast.scala | 1 + .../org/apache/spark/broadcast/HttpBroadcast.scala | 4 +- .../org/apache/spark/broadcast/MultiTracker.scala | 1 + .../org/apache/spark/broadcast/TreeBroadcast.scala | 1 + .../org/apache/spark/deploy/DeployMessage.scala | 2 +- .../apache/spark/deploy/LocalSparkCluster.scala | 4 +- .../apache/spark/deploy/client/TestClient.scala | 4 +- .../org/apache/spark/deploy/master/Master.scala | 4 +- .../spark/deploy/master/MasterArguments.scala | 3 +- .../apache/spark/deploy/master/WorkerInfo.scala | 2 +- .../spark/deploy/master/ui/ApplicationPage.scala | 2 +- .../apache/spark/deploy/master/ui/IndexPage.scala | 2 +- .../spark/deploy/master/ui/MasterWebUI.scala | 3 +- .../spark/deploy/worker/ExecutorRunner.scala | 3 +- .../org/apache/spark/deploy/worker/Worker.scala | 4 +- .../spark/deploy/worker/WorkerArguments.scala | 4 +- .../apache/spark/deploy/worker/ui/IndexPage.scala | 2 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 3 +- .../scala/org/apache/spark/executor/Executor.scala | 1 + .../spark/executor/MesosExecutorBackend.scala | 3 +- .../spark/executor/StandaloneExecutorBackend.scala | 4 +- .../apache/spark/network/ConnectionManager.scala | 1 + .../apache/spark/network/ConnectionManagerId.scala | 2 +- core/src/main/scala/org/apache/spark/package.scala | 21 +- .../spark/partial/ApproximateActionListener.scala | 1 + 
.../main/scala/org/apache/spark/rdd/BlockRDD.scala | 2 +- .../scala/org/apache/spark/rdd/CoGroupedRDD.scala | 2 +- .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 79 ++ .../main/scala/org/apache/spark/rdd/EmptyRDD.scala | 2 +- .../scala/org/apache/spark/rdd/FilteredRDD.scala | 2 +- .../scala/org/apache/spark/rdd/FlatMappedRDD.scala | 2 +- .../org/apache/spark/rdd/FlatMappedValuesRDD.scala | 2 +- .../scala/org/apache/spark/rdd/GlommedRDD.scala | 2 +- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 8 +- .../main/scala/org/apache/spark/rdd/JdbcRDD.scala | 2 +- .../org/apache/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithIndexRDD.scala | 2 +- .../scala/org/apache/spark/rdd/MappedRDD.scala | 2 +- .../org/apache/spark/rdd/MappedValuesRDD.scala | 2 +- .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../org/apache/spark/rdd/OrderedRDDFunctions.scala | 7 +- .../org/apache/spark/rdd/PairRDDFunctions.scala | 702 +++++++++++++++ .../apache/spark/rdd/ParallelCollectionRDD.scala | 2 + .../org/apache/spark/rdd/PartitionPruningRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/PipedRDD.scala | 2 +- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 942 ++++++++++++++++++++ .../org/apache/spark/rdd/RDDCheckpointData.scala | 131 +++ .../scala/org/apache/spark/rdd/SampledRDD.scala | 2 +- .../spark/rdd/SequenceFileRDDFunctions.scala | 89 ++ .../scala/org/apache/spark/rdd/ShuffledRDD.scala | 2 +- .../scala/org/apache/spark/rdd/SubtractedRDD.scala | 1 - .../main/scala/org/apache/spark/rdd/UnionRDD.scala | 2 +- .../org/apache/spark/rdd/ZippedPartitionsRDD.scala | 2 +- .../scala/org/apache/spark/rdd/ZippedRDD.scala | 2 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 1 + .../apache/spark/scheduler/DAGSchedulerEvent.scala | 1 + .../org/apache/spark/scheduler/JobLogger.scala | 1 + .../org/apache/spark/scheduler/ResultTask.scala | 7 +- .../apache/spark/scheduler/ShuffleMapTask.scala | 2 + .../org/apache/spark/scheduler/SparkListener.scala | 4 +- .../scala/org/apache/spark/scheduler/Stage.scala | 3 +- .../org/apache/spark/scheduler/TaskResult.scala | 3 +- .../scheduler/cluster/ClusterTaskSetManager.scala | 2 +- .../spark/scheduler/cluster/SchedulerBackend.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 3 +- .../cluster/StandaloneClusterMessage.scala | 3 +- .../cluster/StandaloneSchedulerBackend.scala | 3 +- .../apache/spark/scheduler/cluster/TaskInfo.scala | 2 +- .../spark/scheduler/local/LocalScheduler.scala | 1 + .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../scheduler/mesos/MesosSchedulerBackend.scala | 3 +- .../apache/spark/serializer/JavaSerializer.scala | 82 ++ .../apache/spark/serializer/KryoSerializer.scala | 159 ++++ .../spark/storage/BlockFetcherIterator.scala | 2 +- .../org/apache/spark/storage/BlockManager.scala | 4 +- .../org/apache/spark/storage/BlockManagerId.scala | 2 +- .../spark/storage/BlockManagerMasterActor.scala | 3 +- .../apache/spark/storage/BlockManagerWorker.scala | 3 +- .../scala/org/apache/spark/storage/DiskStore.scala | 2 +- .../org/apache/spark/storage/MemoryStore.scala | 2 +- .../org/apache/spark/storage/StorageUtils.scala | 3 +- .../org/apache/spark/storage/ThreadingTest.scala | 2 +- .../main/scala/org/apache/spark/ui/SparkUI.scala | 3 +- .../org/apache/spark/ui/exec/ExecutorsUI.scala | 3 +- .../apache/spark/ui/jobs/JobProgressListener.scala | 2 +- .../org/apache/spark/ui/jobs/JobProgressUI.scala | 3 +- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 4 +- .../org/apache/spark/ui/jobs/StageTable.scala | 2 +- 
.../org/apache/spark/ui/storage/IndexPage.scala | 2 +- .../org/apache/spark/ui/storage/RDDPage.scala | 2 +- .../org/apache/spark/util/ClosureCleaner.scala | 232 +++++ .../scala/org/apache/spark/util/MemoryParam.scala | 2 - .../org/apache/spark/util/SizeEstimator.scala | 284 ++++++ .../main/scala/org/apache/spark/util/Utils.scala | 781 +++++++++++++++++ .../scala/org/apache/spark/CheckpointSuite.scala | 1 + .../org/apache/spark/ClosureCleanerSuite.scala | 146 ---- .../test/scala/org/apache/spark/DriverSuite.scala | 1 + .../test/scala/org/apache/spark/FailureSuite.scala | 1 + .../org/apache/spark/KryoSerializerSuite.scala | 208 ----- .../org/apache/spark/PairRDDFunctionsSuite.scala | 299 ------- .../apache/spark/PartitionPruningRDDSuite.scala | 2 +- .../scala/org/apache/spark/PartitioningSuite.scala | 9 +- .../src/test/scala/org/apache/spark/RDDSuite.scala | 389 --------- .../test/scala/org/apache/spark/ShuffleSuite.scala | 3 +- .../org/apache/spark/SizeEstimatorSuite.scala | 164 ---- .../test/scala/org/apache/spark/SortingSuite.scala | 123 --- .../test/scala/org/apache/spark/UtilsSuite.scala | 139 --- .../apache/spark/rdd/PairRDDFunctionsSuite.scala | 300 +++++++ .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 391 +++++++++ .../scala/org/apache/spark/rdd/SortingSuite.scala | 125 +++ .../apache/spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../apache/spark/scheduler/JobLoggerSuite.scala | 6 +- .../apache/spark/scheduler/TaskContextSuite.scala | 2 +- .../cluster/ClusterTaskSetManagerSuite.scala | 2 +- .../spark/serializer/KryoSerializerSuite.scala | 208 +++++ .../apache/spark/storage/BlockManagerSuite.scala | 8 +- .../apache/spark/util/ClosureCleanerSuite.scala | 147 ++++ .../org/apache/spark/util/SizeEstimatorSuite.scala | 164 ++++ .../scala/org/apache/spark/util/UtilsSuite.scala | 139 +++ docs/configuration.md | 10 +- docs/quick-start.md | 2 +- docs/scala-programming-guide.md | 2 +- docs/tuning.md | 10 +- .../spark/examples/bagel/PageRankUtils.scala | 1 + .../spark/examples/bagel/WikipediaPageRank.scala | 2 +- .../bagel/WikipediaPageRankStandalone.scala | 17 +- .../spark/streaming/examples/QueueStream.scala | 2 +- .../spark/streaming/examples/RawNetworkGrep.scala | 2 +- .../mllib/classification/ClassificationModel.scala | 2 +- .../mllib/classification/LogisticRegression.scala | 3 +- .../apache/spark/mllib/classification/SVM.scala | 3 +- .../org/apache/spark/mllib/clustering/KMeans.scala | 3 +- .../spark/mllib/clustering/KMeansModel.scala | 2 +- .../spark/mllib/optimization/GradientDescent.scala | 8 +- .../spark/mllib/optimization/Optimizer.scala | 2 +- .../apache/spark/mllib/recommendation/ALS.scala | 7 +- .../recommendation/MatrixFactorizationModel.scala | 2 +- .../regression/GeneralizedLinearAlgorithm.scala | 3 +- .../org/apache/spark/mllib/regression/Lasso.scala | 3 +- .../spark/mllib/regression/LinearRegression.scala | 3 +- .../spark/mllib/regression/RegressionModel.scala | 2 +- .../spark/mllib/regression/RidgeRegression.scala | 3 +- .../apache/spark/mllib/util/DataValidators.scala | 3 +- .../spark/mllib/util/KMeansDataGenerator.scala | 3 +- .../spark/mllib/util/LinearDataGenerator.scala | 3 +- .../util/LogisticRegressionDataGenerator.scala | 3 +- .../apache/spark/mllib/util/MFDataGenerator.scala | 3 +- .../org/apache/spark/mllib/util/MLUtils.scala | 3 +- .../apache/spark/mllib/util/SVMDataGenerator.scala | 5 +- python/pyspark/context.py | 4 +- .../scala/org/apache/spark/repl/SparkIMain.scala | 2 +- .../scala/org/apache/spark/streaming/DStream.scala | 3 +- 
.../org/apache/spark/streaming/Duration.scala | 2 +- .../spark/streaming/PairDStreamFunctions.scala | 9 +- .../apache/spark/streaming/StreamingContext.scala | 1 + .../spark/streaming/api/java/JavaDStream.scala | 2 +- .../spark/streaming/api/java/JavaDStreamLike.scala | 2 +- .../spark/streaming/api/java/JavaPairDStream.scala | 7 +- .../streaming/api/java/JavaStreamingContext.scala | 25 +- .../spark/streaming/dstream/CoGroupedDStream.scala | 3 +- .../streaming/dstream/ConstantInputDStream.scala | 2 +- .../spark/streaming/dstream/FileInputDStream.scala | 2 +- .../spark/streaming/dstream/FilteredDStream.scala | 2 +- .../streaming/dstream/FlatMapValuedDStream.scala | 2 +- .../streaming/dstream/FlatMappedDStream.scala | 2 +- .../streaming/dstream/FlumeInputDStream.scala | 15 +- .../spark/streaming/dstream/ForEachDStream.scala | 2 +- .../spark/streaming/dstream/GlommedDStream.scala | 2 +- .../streaming/dstream/MapPartitionedDStream.scala | 2 +- .../spark/streaming/dstream/MapValuedDStream.scala | 2 +- .../spark/streaming/dstream/MappedDStream.scala | 2 +- .../streaming/dstream/NetworkInputDStream.scala | 15 +- .../streaming/dstream/QueueInputDStream.scala | 2 +- .../streaming/dstream/ReducedWindowedDStream.scala | 2 +- .../spark/streaming/dstream/ShuffledDStream.scala | 3 +- .../spark/streaming/dstream/StateDStream.scala | 2 +- .../streaming/dstream/TransformedDStream.scala | 2 +- .../spark/streaming/dstream/UnionDStream.scala | 2 +- .../spark/streaming/dstream/WindowedDStream.scala | 2 +- .../spark/streaming/util/MasterFailureTest.scala | 3 +- .../spark/streaming/util/RawTextSender.scala | 3 +- .../org/apache/spark/streaming/TestSuiteBase.scala | 5 +- .../spark/tools/JavaAPICompletenessChecker.scala | 38 +- 210 files changed, 5290 insertions(+), 5246 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ClosureCleaner.scala delete mode 100644 core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala delete mode 100644 core/src/main/scala/org/apache/spark/JavaSerializer.scala delete mode 100644 core/src/main/scala/org/apache/spark/KryoSerializer.scala delete mode 100644 core/src/main/scala/org/apache/spark/PairRDDFunctions.scala delete mode 100644 core/src/main/scala/org/apache/spark/RDD.scala delete mode 100644 core/src/main/scala/org/apache/spark/RDDCheckpointData.scala delete mode 100644 core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala delete mode 100644 core/src/main/scala/org/apache/spark/SizeEstimator.scala delete mode 100644 core/src/main/scala/org/apache/spark/Utils.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/RDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala create mode 100644 core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala create mode 100644 core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/SizeEstimator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Utils.scala delete mode 100644 core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala delete 
mode 100644 core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/RDDSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/SortingSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/UtilsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/UtilsSuite.scala (limited to 'streaming/src') diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index fec8737fcd..44e26bbb9e 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -19,7 +19,7 @@ package org.apache.spark.bagel import org.apache.spark._ import org.apache.spark.SparkContext._ - +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel object Bagel extends Logging { diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5177ee58fa..6e922a612a 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -21,6 +21,7 @@ import java.io._ import scala.collection.mutable.Map import scala.collection.generic.Growable +import org.apache.spark.serializer.JavaSerializer /** * A datatype that can be accumulated, i.e. has an commutative and associative "add" operation, diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 42e465b9d8..e299a106ee 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -19,6 +19,7 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.rdd.RDD /** Spark class responsible for passing RDDs split contents to the BlockManager and making diff --git a/core/src/main/scala/org/apache/spark/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/ClosureCleaner.scala deleted file mode 100644 index 71d9e62d4f..0000000000 --- a/core/src/main/scala/org/apache/spark/ClosureCleaner.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.lang.reflect.Field - -import scala.collection.mutable.Map -import scala.collection.mutable.Set - -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.objectweb.asm.Opcodes._ -import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} - -private[spark] object ClosureCleaner extends Logging { - // Get an ASM class reader for a given class from the JAR that loaded it - private def getClassReader(cls: Class[_]): ClassReader = { - // Copy data over, before delegating to ClassReader - else we can run out of open file handles. - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) - } - - // Check whether a class represents a Scala closure - private def isClosure(cls: Class[_]): Boolean = { - cls.getName.contains("$anonfun$") - } - - // Get a list of the classes of the outer objects of a given closure object, obj; - // the outer objects are defined as any closures that obj is nested within, plus - // possibly the class that the outermost closure is in, if any. We stop searching - // for outer objects beyond that because cloning the user's object is probably - // not a good idea (whereas we can clone closure objects just fine since we - // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(f.get(obj)) - } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure - } - } - return Nil - } - - // Get a list of the outer objects for a given closure object. 
- private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.get(obj) :: getOuterObjects(f.get(obj)) - } else { - return f.get(obj) :: Nil // Stop at the first $outer that is not a closure - } - } - return Nil - } - - private def getInnerClasses(obj: AnyRef): List[Class[_]] = { - val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) - while (!stack.isEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail - val set = Set[Class[_]]() - cr.accept(new InnerClosureFinder(set), 0) - for (cls <- set -- seen) { - seen += cls - stack = cls :: stack - } - } - return (seen - obj.getClass).toList - } - - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type - } else { - null - } - } - - def clean(func: AnyRef) { - // TODO: cache outerClasses / innerClasses / accessedFields - val outerClasses = getOuterClasses(func) - val innerClasses = getInnerClasses(func) - val outerObjects = getOuterObjects(func) - - val accessedFields = Map[Class[_], Set[String]]() - for (cls <- outerClasses) - accessedFields(cls) = Set[String]() - for (cls <- func.getClass :: innerClasses) - getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) - //logInfo("accessedFields: " + accessedFields) - - val inInterpreter = { - try { - val interpClass = Class.forName("spark.repl.Main") - interpClass.getMethod("interp").invoke(null) != null - } catch { - case _: ClassNotFoundException => true - } - } - - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var outer: AnyRef = null - if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - outer = outerPairs.head._2 - outerPairs = outerPairs.tail - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. 
- for ((cls, obj) <- outerPairs) { - outer = instantiateClass(cls, outer, inInterpreter) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); - field.set(outer, value) - } - } - - if (outer != null) { - //logInfo("2: Setting $outer on " + func.getClass + " to " + outer); - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - field.set(func, outer) - } - } - - private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { - //logInfo("Creating a " + cls + " with outer = " + outer) - if (!inInterpreter) { - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue).toArray - if (outer != null) - params(0) = outer // First param is always outer object - return cons.newInstance(params: _*).asInstanceOf[AnyRef] - } else { - // Use reflection to instantiate object without calling constructor - val rf = sun.reflect.ReflectionFactory.getReflectionFactory() - val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() - val newCtor = rf.newConstructorForSerialization(cls, parentCtor) - val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (outer != null) { - //logInfo("3: Setting $outer on " + cls + " to " + outer); - val field = cls.getDeclaredField("$outer") - field.setAccessible(true) - field.set(obj, outer) - } - return obj - } - } -} - -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { - override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { - if (op == GETFIELD) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name - } - } - } - - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - // Check for calls a getter method for a variable in an interpreter wrapper object. - // This means that the corresponding field will be accessed, so we should save it. - if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name - } - } - } - } - } -} - -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { - var myName: String = null - - override def visit(version: Int, access: Int, name: String, sig: String, - superName: String, interfaces: Array[String]) { - myName = name - } - - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - val argTypes = Type.getArgumentTypes(desc) - if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 - && argTypes(0).toString.startsWith("L") // is it an object? 
- && argTypes(0).getInternalName == myName) - output += Class.forName( - owner.replace('/', '.'), - false, - Thread.currentThread.getContextClassLoader) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index cc3c2474a6..cc30105940 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,6 +17,8 @@ package org.apache.spark +import org.apache.spark.rdd.RDD + /** * Base class for dependencies. */ diff --git a/core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala deleted file mode 100644 index dd344491b8..0000000000 --- a/core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.spark.partial.BoundedDouble -import org.apache.spark.partial.MeanEvaluator -import org.apache.spark.partial.PartialResult -import org.apache.spark.partial.SumEvaluator -import org.apache.spark.util.StatCounter - -/** - * Extra functions available on RDDs of Doubles through an implicit conversion. - * Import `spark.SparkContext._` at the top of your program to use these functions. - */ -class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { - /** Add up the elements in this RDD. */ - def sum(): Double = { - self.reduce(_ + _) - } - - /** - * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count - * of the RDD's elements in one operation. - */ - def stats(): StatCounter = { - self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) - } - - /** Compute the mean of this RDD's elements. */ - def mean(): Double = stats().mean - - /** Compute the variance of this RDD's elements. */ - def variance(): Double = stats().variance - - /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = stats().stdev - - /** - * Compute the sample standard deviation of this RDD's elements (which corrects for bias in - * estimating the standard deviation by dividing by N-1 instead of N). - */ - def sampleStdev(): Double = stats().sampleStdev - - /** - * Compute the sample variance of this RDD's elements (which corrects for bias in - * estimating the variance by dividing by N-1 instead of N). - */ - def sampleVariance(): Double = stats().sampleVariance - - /** (Experimental) Approximate operation to return the mean within a timeout. 
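A minimal usage sketch of these implicit Double functions, assuming a live SparkContext named sc:

    import org.apache.spark.SparkContext._           // brings DoubleRDDFunctions into scope
    val xs = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    xs.sum()    // 10.0
    xs.mean()   // 2.5
    xs.stats()  // StatCounter with count, mean and variance computed in one pass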
*/ - def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } -} diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 9b3a896648..ad1ee20045 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File} import com.google.common.io.Files +import org.apache.spark.util.Utils private[spark] class HttpFileServer extends Logging { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index db36c7c9dd..cdfc9dd54e 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -26,6 +26,7 @@ import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.apache.spark.util.Utils /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. diff --git a/core/src/main/scala/org/apache/spark/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/JavaSerializer.scala deleted file mode 100644 index f43396cb6b..0000000000 --- a/core/src/main/scala/org/apache/spark/JavaSerializer.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import java.io._ -import java.nio.ByteBuffer - -import serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream} -import org.apache.spark.util.ByteBufferInputStream - -private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { - val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } - def flush() { objOut.flush() } - def close() { objOut.close() } -} - -private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - - def readObject[T](): T = objIn.readObject().asInstanceOf[T] - def close() { objIn.close() } -} - -private[spark] class JavaSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() - val out = serializeStream(bos) - out.writeObject(t) - out.close() - ByteBuffer.wrap(bos.toByteArray) - } - - def deserialize[T](bytes: ByteBuffer): T = { - val bis = new ByteBufferInputStream(bytes) - val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val bis = new ByteBufferInputStream(bytes) - val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] - } - - def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader) - } - - def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { - new JavaDeserializationStream(s, loader) - } -} - -/** - * A Spark serializer that uses Java's built-in serialization. - */ -class JavaSerializer extends Serializer { - def newInstance(): SerializerInstance = new JavaSerializerInstance -} diff --git a/core/src/main/scala/org/apache/spark/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/KryoSerializer.scala deleted file mode 100644 index db86e6db43..0000000000 --- a/core/src/main/scala/org/apache/spark/KryoSerializer.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
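A minimal sketch of the SerializerInstance round trip implemented above, assuming the class is used from its new org.apache.spark.serializer location:

    val ser   = new JavaSerializer().newInstance()
    val bytes = ser.serialize(List(1, 2, 3))        // returns a java.nio.ByteBuffer
    val back  = ser.deserialize[List[Int]](bytes)   // List(1, 2, 3)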
- */ - -package org.apache.spark - -import java.io._ -import java.nio.ByteBuffer -import com.esotericsoftware.kryo.{Kryo, KryoException} -import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} -import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} -import com.twitter.chill.ScalaKryoInstantiator -import serializer.{SerializerInstance, DeserializationStream, SerializationStream} -import org.apache.spark.broadcast._ -import org.apache.spark.storage._ - -private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) - - def writeObject[T](t: T): SerializationStream = { - kryo.writeClassAndObject(output, t) - this - } - - def flush() { output.flush() } - def close() { output.close() } -} - -private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) - - def readObject[T](): T = { - try { - kryo.readClassAndObject(input).asInstanceOf[T] - } catch { - // DeserializationStream uses the EOF exception to indicate stopping condition. - case _: KryoException => throw new EOFException - } - } - - def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() - } -} - -private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() - val output = ks.newKryoOutput() - val input = ks.newKryoInput() - - def serialize[T](t: T): ByteBuffer = { - output.clear() - kryo.writeClassAndObject(output, t) - ByteBuffer.wrap(output.toBytes) - } - - def deserialize[T](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj - } - - def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) - } -} - -/** - * Interface implemented by clients to register their classes with Kryo when using Kryo - * serialization. - */ -trait KryoRegistrator { - def registerClasses(kryo: Kryo) -} - -/** - * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. 
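A minimal sketch of the registration hook this trait provides; MyRecord and the mypackage prefix are hypothetical, while spark.kryo.registrator is the property this serializer reads when building its Kryo instance:

    class MyRegistrator extends KryoRegistrator {
      def registerClasses(kryo: Kryo) {
        kryo.register(classOf[MyRecord])   // registered classes are written as compact ids, not full names
      }
    }
    // Set before the SparkContext is created:
    System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator")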
- */ -class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { - private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - - def newKryoOutput() = new KryoOutput(bufferSize) - - def newKryoInput() = new KryoInput(bufferSize) - - def newKryo(): Kryo = { - val instantiator = new ScalaKryoInstantiator - val kryo = instantiator.newKryo() - val classLoader = Thread.currentThread.getContextClassLoader - - // Register some commonly used classes - val toRegister: Seq[AnyRef] = Seq( - ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1") - ) - - for (obj <- toRegister) kryo.register(obj.getClass) - - // Allow sending SerializableWritable - kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) - - // Allow the user to register their own classes by setting spark.kryo.registrator - try { - Option(System.getProperty("spark.kryo.registrator")).foreach { regCls => - logDebug("Running user registrator: " + regCls) - val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] - reg.registerClasses(kryo) - } - } catch { - case _: Exception => println("Failed to register spark.kryo.registrator") - } - - kryo.setClassLoader(classLoader) - - // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops - kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) - - kryo - } - - def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this) - } -} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 0f422d910a..ae7cf2a893 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -32,7 +32,7 @@ import akka.util.Duration import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage diff --git a/core/src/main/scala/org/apache/spark/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/PairRDDFunctions.scala deleted file mode 100644 index d046e7c1a4..0000000000 --- a/core/src/main/scala/org/apache/spark/PairRDDFunctions.scala +++ /dev/null @@ -1,703 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import java.nio.ByteBuffer -import java.util.{Date, HashMap => JHashMap} -import java.text.SimpleDateFormat - -import scala.collection.{mutable, Map} -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.mapred.FileOutputFormat -import org.apache.hadoop.mapred.SparkHadoopWriter -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat - -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, SparkHadoopMapReduceUtil} -import org.apache.hadoop.security.UserGroupInformation - -import org.apache.spark.partial.BoundedDouble -import org.apache.spark.partial.PartialResult -import org.apache.spark.rdd._ -import org.apache.spark.SparkContext._ -import org.apache.spark.Partitioner._ - -/** - * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. - * Import `spark.SparkContext._` at the top of your program to use these functions. - */ -class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable { - - /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: - * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. - * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). - */ - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C, - partitioner: Partitioner, - mapSideCombine: Boolean = true, - serializerClass: String = null): RDD[(K, C)] = { - if (getKeyClass().isArray) { - if (mapSideCombine) { - throw new SparkException("Cannot use map-side combining with array keys.") - } - if (partitioner.isInstanceOf[HashPartitioner]) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - } - val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - } else if (mapSideCombine) { - val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) - } else { - // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. 
- assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - } - } - - /** - * Simplified version of combineByKey that hash-partitions the output RDD. - */ - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C, - numPartitions: Int): RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - // Serialize the zero value to a byte array so that we can get a new clone of it on each key - val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) - val zeroArray = new Array[Byte](zeroBuffer.limit) - zeroBuffer.get(zeroArray) - - // When deserializing, use a lazy val to create just one instance of the serializer per task - lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) - - combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { - foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { - foldByKey(zeroValue, defaultPartitioner(self))(func) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]((v: V) => v, func, func, partitioner) - } - - /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. 
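A minimal sketch of the three functions combineByKey expects, building a per-key (sum, count) pair to derive averages; a pairs: RDD[(String, Int)] with the SparkContext._ implicits in scope is assumed:

    val sumCounts = pairs.combineByKey(
      (v: Int) => (v, 1),                                           // createCombiner
      (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1),        // mergeValue
      (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2)) // mergeCombiners
    val avgByKey = sumCounts.mapValues { case (sum, count) => sum.toDouble / count }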
- */ - def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { - - if (getKeyClass().isArray) { - throw new SparkException("reduceByKeyLocally() does not support array keys") - } - - def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { - val map = new JHashMap[K, V] - iter.foreach { case (k, v) => - val old = map.get(k) - map.put(k, if (old == null) v else func(old, v)) - } - Iterator(map) - } - - def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { - m2.foreach { case (k, v) => - val old = m1.get(k) - m1.put(k, if (old == null) v else func(old, v)) - } - m1 - } - - self.mapPartitions(reducePartition).reduce(mergeMaps) - } - - /** Alias for reduceByKeyLocally */ - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - - /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. - */ - def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[Map[K, BoundedDouble]] = { - self.map(_._1).countByValueApprox(timeout, confidence) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - - /** - * Group the values for each key in the RDD into a single sequence. Allows controlling the - * partitioning of the resulting key-value pair RDD by passing a Partitioner. - */ - def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { - // groupByKey shouldn't use map side combine because map side combine does not - // reduce the amount of data shuffled and requires all map side data be inserted - // into a hash table, leading to more objects in the old gen. - def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) - bufs.asInstanceOf[RDD[(K, Seq[V])]] - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { - groupByKey(new HashPartitioner(numPartitions)) - } - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { - if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - new ShuffledRDD[K, V, (K, V)](self, partitioner) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. 
- */ - def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - for (v <- vs.iterator; w <- ws.iterator) yield (v, w) - } - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (ws.isEmpty) { - vs.iterator.map(v => (v, None)) - } else { - for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w)) - } - } - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (vs.isEmpty) { - ws.iterator.map(w => (None, w)) - } else { - for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w) - } - } - } - - /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the - * existing partitioner/parallelism level. - */ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the existing partitioner/parallelism level. - */ - def groupByKey(): RDD[(K, Seq[V])] = { - groupByKey(defaultPartitioner(self)) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { - join(other, defaultPartitioner(self, other)) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { - join(other, new HashPartitioner(numPartitions)) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. 
Hash-partitions the output - * using the existing partitioner/parallelism level. - */ - def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, defaultPartitioner(self, other)) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numPartitions` partitions. - */ - def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, new HashPartitioner(numPartitions)) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the existing partitioner/parallelism level. - */ - def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, defaultPartitioner(self, other)) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD into the given number of partitions. - */ - def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, new HashPartitioner(numPartitions)) - } - - /** - * Return the key-value pairs in this RDD to the master as a Map. - */ - def collectAsMap(): Map[K, V] = { - val data = self.toArray() - val map = new mutable.HashMap[K, V] - map.sizeHint(data.length) - data.foreach { case (k, v) => map.put(k, v) } - map - } - - /** - * Pass each value in the key-value pair RDD through a map function without changing the keys; - * this also retains the original RDD's partitioning. - */ - def mapValues[U](f: V => U): RDD[(K, U)] = { - val cleanF = self.context.clean(f) - new MappedValuesRDD(self, cleanF) - } - - /** - * Pass each value in the key-value pair RDD through a flatMap function without changing the - * keys; this also retains the original RDD's partitioning. - */ - def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { - val cleanF = self.context.clean(f) - new FlatMappedValuesRDD(self, cleanF) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { - if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) - prfs.mapValues { case Seq(vs, ws) => - (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) - } - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. 
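// A sketch of the join family above; `users` and `orders` are hypothetical key-value RDDs
// sharing String keys, and `sc` is an existing SparkContext.
val users  = sc.parallelize(Seq(("u1", "Alice"), ("u2", "Bob")))
val orders = sc.parallelize(Seq(("u1", 42.0), ("u1", 7.5)))

val inner = users.join(orders)            // RDD[(String, (String, Double))]
val left  = users.leftOuterJoin(orders)   // RDD[(String, (String, Option[Double]))]
val right = users.rightOuterJoin(orders)  // RDD[(String, (Option[String], Double))]
val both  = users.cogroup(orders)         // RDD[(String, (Seq[String], Seq[Double]))]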
- */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) - prfs.mapValues { case Seq(vs, w1s, w2s) => - (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) - } - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner(self, other)) - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, defaultPartitioner(self, other1, other2)) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, new HashPartitioner(numPartitions)) - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, new HashPartitioner(numPartitions)) - } - - /** Alias for cogroup. */ - def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner(self, other)) - } - - /** Alias for cogroup. */ - def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, defaultPartitioner(self, other1, other2)) - } - - /** - * Return an RDD with the pairs from `this` whose keys are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) - - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = - subtractByKey(other, new HashPartitioner(numPartitions)) - - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = - new SubtractedRDD[K, V, W](self, other, p) - - /** - * Return the list of values in the RDD for key `key`. This operation is done efficiently if the - * RDD has a known partitioner by only searching the partition that the key maps to. 
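// A sketch of subtractByKey and lookup as documented above; the data is illustrative and
// `sc` is assumed to exist.
val all     = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
val exclude = sc.parallelize(Seq((2, ())))

val remaining  = all.subtractByKey(exclude)   // keeps keys 1 and 3, reuses `all`'s partitioning
val valuesFor1 = all.lookup(1)                // Seq("a"); scans a single partition when a partitioner is known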
- */ - def lookup(key: K): Seq[V] = { - self.partitioner match { - case Some(p) => - val index = p.getPartition(key) - def process(it: Iterator[(K, V)]): Seq[V] = { - val buf = new ArrayBuffer[V] - for ((k, v) <- it if k == key) { - buf += v - } - buf - } - val res = self.context.runJob(self, process _, Array(index), false) - res(0) - case None => - self.filter(_._1 == key).map(_._2).collect() - } - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. - */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. Compress the result with the - * supplied codec. - */ - def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(conf) - job.setOutputKeyClass(keyClass) - job.setOutputValueClass(valueClass) - val wrappedConf = new SerializableWritable(job.getConfiguration) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = outputFormatClass.newInstance - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] - while (iter.hasNext) { - val (k, v) = iter.next() - writer.write(k, v) - } - writer.close(hadoopContext) - committer.commitTask(hadoopContext) - return 1 - } - val jobFormat = outputFormatClass.newInstance - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. 
- */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - val count = self.context.runJob(self, writeShard _).sum - jobCommitter.commitJob(jobTaskContext) - jobCommitter.cleanupJob(jobTaskContext) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. Compress with the supplied codec. - */ - def saveAsHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - codec: Class[_ <: CompressionCodec]) { - saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, - new JobConf(self.context.hadoopConfiguration), Some(codec)) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. - */ - def saveAsHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration), - codec: Option[Class[_ <: CompressionCodec]] = None) { - conf.setOutputKeyClass(keyClass) - conf.setOutputValueClass(valueClass) - // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug - conf.set("mapred.output.format.class", outputFormatClass.getName) - for (c <- codec) { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) - } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) - saveAsHadoopDataset(conf) - } - - /** - * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for - * that storage system. The JobConf should set an OutputFormat and any output paths required - * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop - * MapReduce job. - */ - def saveAsHadoopDataset(conf: JobConf) { - val outputFormatClass = conf.getOutputFormat - val keyClass = conf.getOutputKeyClass - val valueClass = conf.getOutputValueClass - if (outputFormatClass == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - - logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") - - val writer = new SparkHadoopWriter(conf) - writer.preSetup() - - def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. 
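// A sketch of writing a pair RDD through the old and new Hadoop OutputFormat APIs described
// above; the output paths are hypothetical and the format classes are the standard Hadoop ones.
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat

val writables = sc.parallelize(Seq((1, "a"), (2, "b")))
  .map { case (k, v) => (new IntWritable(k), new Text(v)) }
writables.saveAsHadoopFile[TextOutputFormat[IntWritable, Text]]("hdfs:///tmp/out-old-api")
// The new-API format is referenced by its full name to avoid clashing with the import above.
writables.saveAsNewAPIHadoopFile[
  org.apache.hadoop.mapreduce.lib.output.TextOutputFormat[IntWritable, Text]]("hdfs:///tmp/out-new-api")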
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt - - writer.setup(context.stageId, context.splitId, attemptNumber) - writer.open() - - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - } - - writer.close() - writer.commit() - } - - self.context.runJob(self, writeToFile _) - writer.commitJob() - writer.cleanup() - } - - /** - * Return an RDD with the keys of each tuple. - */ - def keys: RDD[K] = self.map(_._1) - - /** - * Return an RDD with the values of each tuple. - */ - def values: RDD[V] = self.map(_._2) - - private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure - - private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure -} - - -private[spark] object Manifests { - val seqSeqManifest = classManifest[Seq[Seq[_]]] -} diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 4dce2607b0..0e2c987a59 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -17,6 +17,9 @@ package org.apache.spark +import org.apache.spark.util.Utils +import org.apache.spark.rdd.RDD + /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. diff --git a/core/src/main/scala/org/apache/spark/RDD.scala b/core/src/main/scala/org/apache/spark/RDD.scala deleted file mode 100644 index 0d1f07f76c..0000000000 --- a/core/src/main/scala/org/apache/spark/RDD.scala +++ /dev/null @@ -1,957 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import java.util.Random - -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.TextOutputFormat - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -import org.apache.spark.Partitioner._ -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.partial.BoundedDouble -import org.apache.spark.partial.CountEvaluator -import org.apache.spark.partial.GroupedCountEvaluator -import org.apache.spark.partial.PartialResult -import org.apache.spark.rdd.CoalescedRDD -import org.apache.spark.rdd.CartesianRDD -import org.apache.spark.rdd.FilteredRDD -import org.apache.spark.rdd.FlatMappedRDD -import org.apache.spark.rdd.GlommedRDD -import org.apache.spark.rdd.MappedRDD -import org.apache.spark.rdd.MapPartitionsRDD -import org.apache.spark.rdd.MapPartitionsWithIndexRDD -import org.apache.spark.rdd.PipedRDD -import org.apache.spark.rdd.SampledRDD -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.rdd.UnionRDD -import org.apache.spark.rdd.ZippedRDD -import org.apache.spark.rdd.ZippedPartitionsRDD2 -import org.apache.spark.rdd.ZippedPartitionsRDD3 -import org.apache.spark.rdd.ZippedPartitionsRDD4 -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.BoundedPriorityQueue - -import SparkContext._ - -/** - * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, - * partitioned collection of elements that can be operated on in parallel. This class contains the - * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, - * [[org.apache.spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such - * as `groupByKey` and `join`; [[org.apache.spark.DoubleRDDFunctions]] contains operations available only on - * RDDs of Doubles; and [[org.apache.spark.SequenceFileRDDFunctions]] contains operations available on RDDs - * that can be saved as SequenceFiles. These operations are automatically available on any RDD of - * the right type (e.g. RDD[(Int, Int)] through implicit conversions when you - * `import org.apache.spark.SparkContext._`. - * - * Internally, each RDD is characterized by five main properties: - * - * - A list of partitions - * - A function for computing each split - * - A list of dependencies on other RDDs - * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) - * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for - * an HDFS file) - * - * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD - * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for - * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. 
- */ -abstract class RDD[T: ClassManifest]( - @transient private var sc: SparkContext, - @transient private var deps: Seq[Dependency[_]] - ) extends Serializable with Logging { - - /** Construct an RDD with just a one-to-one dependency on one parent */ - def this(@transient oneParent: RDD[_]) = - this(oneParent.context , List(new OneToOneDependency(oneParent))) - - // ======================================================================= - // Methods that should be implemented by subclasses of RDD - // ======================================================================= - - /** Implemented by subclasses to compute a given partition. */ - def compute(split: Partition, context: TaskContext): Iterator[T] - - /** - * Implemented by subclasses to return the set of partitions in this RDD. This method will only - * be called once, so it is safe to implement a time-consuming computation in it. - */ - protected def getPartitions: Array[Partition] - - /** - * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only - * be called once, so it is safe to implement a time-consuming computation in it. - */ - protected def getDependencies: Seq[Dependency[_]] = deps - - /** Optionally overridden by subclasses to specify placement preferences. */ - protected def getPreferredLocations(split: Partition): Seq[String] = Nil - - /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None - - // ======================================================================= - // Methods and fields available on all RDDs - // ======================================================================= - - /** The SparkContext that created this RDD. */ - def sparkContext: SparkContext = sc - - /** A unique ID for this RDD (within its SparkContext). */ - val id: Int = sc.newRddId() - - /** A friendly name for this RDD */ - var name: String = null - - /** Assign a name to this RDD */ - def setName(_name: String) = { - name = _name - this - } - - /** User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo.firstUserClass - - /** Reset generator*/ - def setGenerator(_generator: String) = { - generator = _generator - } - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. This can only be used to assign a new storage level if the RDD does not - * have a storage level set yet.. - */ - def persist(newLevel: StorageLevel): RDD[T] = { - // TODO: Handle changes of StorageLevel - if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { - throw new UnsupportedOperationException( - "Cannot change storage level of an RDD after it was already assigned a level") - } - storageLevel = newLevel - // Register the RDD with the SparkContext - sc.persistentRdds(id) = this - this - } - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): RDD[T] = persist() - - /** - * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. - * - * @param blocking Whether to block until all blocks are deleted. - * @return This RDD. 
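// A sketch of the storage-level methods above; the input path is hypothetical and `sc` exists.
import org.apache.spark.storage.StorageLevel

val logs = sc.textFile("hdfs:///tmp/logs")
logs.persist(StorageLevel.MEMORY_AND_DISK)   // explicit level; may only be assigned once
logs.count()                                  // first action materializes the cached blocks
logs.count()                                  // subsequent actions read from the cache
logs.unpersist()                              // removes the blocks (blocking by default)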
- */ - def unpersist(blocking: Boolean = true): RDD[T] = { - logInfo("Removing RDD " + id + " from persistence list") - sc.env.blockManager.master.removeRdd(id, blocking) - sc.persistentRdds.remove(id) - storageLevel = StorageLevel.NONE - this - } - - /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel - - // Our dependencies and partitions will be gotten by calling subclass's methods below, and will - // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null - - /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) - - /** - * Get the list of dependencies of this RDD, taking into account whether the - * RDD is checkpointed or not. - */ - final def dependencies: Seq[Dependency[_]] = { - checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { - if (dependencies_ == null) { - dependencies_ = getDependencies - } - dependencies_ - } - } - - /** - * Get the array of partitions of this RDD, taking into account whether the - * RDD is checkpointed or not. - */ - final def partitions: Array[Partition] = { - checkpointRDD.map(_.partitions).getOrElse { - if (partitions_ == null) { - partitions_ = getPartitions - } - partitions_ - } - } - - /** - * Get the preferred locations of a partition (as hostnames), taking into account whether the - * RDD is checkpointed. - */ - final def preferredLocations(split: Partition): Seq[String] = { - checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { - getPreferredLocations(split) - } - } - - /** - * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. - * This should ''not'' be called by users directly, but is available for implementors of custom - * subclasses of RDD. - */ - final def iterator(split: Partition, context: TaskContext): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) - } else { - computeOrReadCheckpoint(split, context) - } - } - - /** - * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. - */ - private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - firstParent[T].iterator(split, context) - } else { - compute(split, context) - } - } - - // Transformations (return a new RDD) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = - new FlatMappedRDD(this, sc.clean(f)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): RDD[T] = - map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) - - def distinct(): RDD[T] = distinct(partitions.size) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. 
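// A sketch of the basic transformations above, using an assumed SparkContext `sc`.
val nums    = sc.parallelize(1 to 10, 4)
val squares = nums.map(x => x * x)
val evens   = squares.filter(_ % 2 == 0)
val uniq    = nums.union(nums).distinct()   // de-duplicates via reduceByKey, as implemented above
val fewer   = uniq.coalesce(2)              // narrow coalesce; pass shuffle = true to redistribute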
- */ - def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = { - if (shuffle) { - // include a shuffle step so that our upstream tasks are still distributed - new CoalescedRDD( - new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), - new HashPartitioner(numPartitions)), - numPartitions).keys - } else { - new CoalescedRDD(this, numPartitions) - } - } - - /** - * Return a sampled subset of this RDD. - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = - new SampledRDD(this, withReplacement, fraction, seed) - - def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { - var fraction = 0.0 - var total = 0 - val multiplier = 3.0 - val initialCount = this.count() - var maxSelected = 0 - - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } - - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - 1 - } else { - maxSelected = initialCount.toInt - } - - if (num > initialCount && !withReplacement) { - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = multiplier * (num + 1) / initialCount - total = num - } - - val rand = new Random(seed) - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for thei initial size - while (samples.length < total) { - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - } - - Utils.randomizeInPlace(samples, rand).take(total) - } - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def ++(other: RDD[T]): RDD[T] = this.union(other) - - /** - * Return an RDD created by coalescing all elements within each partition into an array. - */ - def glom(): RDD[Array[T]] = new GlommedRDD(this) - - /** - * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of - * elements (a, b) where a is in `this` and b is in `other`. - */ - def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) - - /** - * Return an RDD of grouped items. - */ - def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = - groupBy[K](f, defaultPartitioner(this)) - - /** - * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. - */ - def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = - groupBy(f, new HashPartitioner(numPartitions)) - - /** - * Return an RDD of grouped items. - */ - def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { - val cleanF = sc.clean(f) - this.map(t => (cleanF(t), t)).groupByKey(p) - } - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: String): RDD[String] = new PipedRDD(this, command) - - /** - * Return an RDD created by piping elements to a forked external process. 
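// A sketch of sampling, grouping and piping as documented above; `sc` is assumed.
val nums = sc.parallelize(1 to 100)
val tenPercent = nums.sample(withReplacement = false, fraction = 0.1, seed = 42)
val fiveExact  = nums.takeSample(withReplacement = false, num = 5, seed = 42)  // Array[Int] on the driver
val byParity   = nums.groupBy(n => n % 2)                                      // RDD[(Int, Seq[Int])]
val piped      = sc.parallelize(Seq("a", "b")).pipe("cat")                     // one external process per partition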
- */ - def pipe(command: String, env: Map[String, String]): RDD[String] = - new PipedRDD(this, command, env) - - - /** - * Return an RDD created by piping elements to a forked external process. - * The print behavior can be customized by providing two functions. - * - * @param command command to run in forked process. - * @param env environment variables to set. - * @param printPipeContext Before piping elements, this function is called as an oppotunity - * to pipe context data. Print line function (like out.println) will be - * passed as printPipeContext's parameter. - * @param printRDDElement Use this function to customize how to pipe elements. This function - * will be called with each RDD element as the 1st parameter, and the - * print line function (like out.println()) as the 2nd parameter. - * An example of pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} - * @return the result RDD - */ - def pipe( - command: Seq[String], - env: Map[String, String] = Map(), - printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = - new PipedRDD(this, command, env, - if (printPipeContext ne null) sc.clean(printPipeContext) else null, - if (printRDDElement ne null) sc.clean(printRDDElement) else null) - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - def mapPartitionsWithIndex[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) - - /** - * Maps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => U): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.map(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) - } - - /** - * FlatMaps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. 
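// A sketch of the per-partition operations above; `sc` is assumed.
val nums = sc.parallelize(1 to 8, 4)
val partitionSums = nums.mapPartitions(iter => Iterator(iter.sum))
val tagged = nums.mapPartitionsWithIndex { (index, iter) =>
  iter.map(n => (index, n))   // pair each element with the index of its partition
}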
- */ - def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) - } - - /** - * Applies f to each element of this RDD, where f takes an additional parameter of type A. - * This additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def foreachWith[A: ClassManifest](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) - } - - /** - * Filters this RDD with p, where p takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def filterWith[A: ClassManifest](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) - } - - /** - * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, - * second element in each RDD, etc. Assumes that the two RDDs have the *same number of - * partitions* and the *same number of elements in each partition* (e.g. one was made through - * a map on the other). - */ - def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) - - /** - * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by - * applying a function to the zipped partitions. Assumes that all the RDDs have the - * *same number of partitions*, but does *not* require them to have the same number - * of elements in each partition. - */ - def zipPartitions[B: ClassManifest, V: ClassManifest] - (rdd2: RDD[B]) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) - - def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] - (rdd2: RDD[B], rdd3: RDD[C]) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) - - def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] - (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) - - - // Actions (launch a job to return a value to the user program) - - /** - * Applies a function f to all elements of this RDD. - */ - def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) - } - - /** - * Applies a function f to each partition of this RDD. - */ - def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) - } - - /** - * Return an array that contains all of the elements in this RDD. 
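// A sketch of zip and zipPartitions above; both RDDs are built with the same number of
// partitions, as the contract requires, and `sc` is assumed.
val a = sc.parallelize(1 to 4, 2)
val b = sc.parallelize(Seq("w", "x", "y", "z"), 2)
val zipped = a.zip(b)                         // RDD[(Int, String)], element by element
val counts = a.zipPartitions(b) { (xs, ys) =>
  Iterator(xs.size + ys.size)                 // one combined element count per partition
}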
- */ - def collect(): Array[T] = { - val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - Array.concat(results: _*) - } - - /** - * Return an array that contains all of the elements in this RDD. - */ - def toArray(): Array[T] = collect() - - /** - * Return an RDD that contains all matching values by applying `f`. - */ - def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { - filter(f.isDefinedAt).map(f) - } - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: RDD[T], numPartitions: Int): RDD[T] = - subtract(other, new HashPartitioner(numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: RDD[T], p: Partitioner): RDD[T] = { - if (partitioner == Some(p)) { - // Our partitioner knows how to handle T (which, since we have a partitioner, is - // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples - val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) - } - // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies - // anyway, and when calling .keys, will not have a partitioner set, even though - // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be - // partitioned by the right/real keys (e.g. p). - this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys - } else { - this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys - } - } - - /** - * Reduces the elements of this RDD using the specified commutative and associative binary operator. - */ - def reduce(f: (T, T) => T): T = { - val cleanF = sc.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - var jobResult: Option[T] = None - val mergeResult = (index: Int, taskResult: Option[T]) => { - if (taskResult != None) { - jobResult = jobResult match { - case Some(value) => Some(f(value, taskResult.get)) - case None => taskResult - } - } - } - sc.runJob(this, reducePartition, mergeResult) - // Get the final result out of our Option, or throw an exception if the RDD was empty - jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) - } - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. 
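// A sketch of subtract, reduce and fold as documented above; `sc` is assumed.
val nums  = sc.parallelize(1 to 10)
val odds  = sc.parallelize(Seq(1, 3, 5, 7, 9))
val evens = nums.subtract(odds)           // keeps nums' partitioning, as noted above
val total = nums.reduce(_ + _)            // 55
val same  = nums.fold(0)(_ + _)           // also 55; the zero value must be neutral for the operator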
- */ - def fold(zeroValue: T)(op: (T, T) => T): T = { - // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) - val cleanOp = sc.clean(op) - val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) - val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) - sc.runJob(this, foldPartition, mergeResult) - jobResult - } - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using - * given combine functions and a neutral "zero value". This function can return a different result - * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U - * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are - * allowed to modify and return their first argument instead of creating a new U to avoid memory - * allocation. - */ - def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { - // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) - val cleanSeqOp = sc.clean(seqOp) - val cleanCombOp = sc.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) - sc.runJob(this, aggregatePartition, mergeResult) - jobResult - } - - /** - * Return the number of elements in the RDD. - */ - def count(): Long = { - sc.runJob(this, (iter: Iterator[T]) => { - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - }).sum - } - - /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result - * within a timeout, even if not all tasks have finished. - */ - def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - } - val evaluator = new CountEvaluator(partitions.size, confidence) - sc.runApproximateJob(this, countElements, evaluator, timeout) - } - - /** - * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final - * combine step happens locally on the master, equivalent to running a single reduce task. - */ - def countByValue(): Map[T, Long] = { - if (elementClassManifest.erasure.isArray) { - throw new SparkException("countByValue() does not support arrays") - } - // TODO: This should perhaps be distributed by default. - def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - Iterator(map) - } - def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { - val iter = m2.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) - } - return m1 - } - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map - } - - /** - * (Experimental) Approximate version of countByValue(). 
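// A sketch of aggregate and the counting actions above; `sc` is assumed.
val words = sc.parallelize(Seq("a", "bb", "a", "ccc"))
val totalChars = words.aggregate(0)((acc, w) => acc + w.length, _ + _)  // seqOp per element, combOp per partition
val n          = words.count()                                          // 4
val histogram  = words.countByValue()                                   // Map("a" -> 2, "bb" -> 1, "ccc" -> 1)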
- */ - def countByValueApprox( - timeout: Long, - confidence: Double = 0.95 - ): PartialResult[Map[T, BoundedDouble]] = { - if (elementClassManifest.erasure.isArray) { - throw new SparkException("countByValueApprox() does not support arrays") - } - val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - map - } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) - sc.runApproximateJob(this, countPartition, evaluator, timeout) - } - - /** - * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so - * it will be slow if a lot of partitions are required. In that case, use collect() to get the - * whole RDD instead. - */ - def take(num: Int): Array[T] = { - if (num == 0) { - return new Array[T](0) - } - val buf = new ArrayBuffer[T] - var p = 0 - while (buf.size < num && p < partitions.size) { - val left = num - buf.size - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true) - buf ++= res(0) - if (buf.size == num) - return buf.toArray - p += 1 - } - return buf.toArray - } - - /** - * Return the first element in this RDD. - */ - def first(): T = take(1) match { - case Array(t) => t - case _ => throw new UnsupportedOperationException("empty collection") - } - - /** - * Returns the top K elements from this RDD as defined by - * the specified implicit Ordering[T]. - * @param num the number of top elements to return - * @param ord the implicit ordering for T - * @return an array of top elements - */ - def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { - mapPartitions { items => - val queue = new BoundedPriorityQueue[T](num) - queue ++= items - Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord.reverse) - } - - /** - * Returns the first K elements from this RDD as defined by - * the specified implicit Ordering[T] and maintains the - * ordering. - * @param num the number of top elements to return - * @param ord the implicit ordering for T - * @return an array of top elements - */ - def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse) - - /** - * Save this RDD as a text file, using string representations of elements. - */ - def saveAsTextFile(path: String) { - this.map(x => (NullWritable.get(), new Text(x.toString))) - .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) - } - - /** - * Save this RDD as a compressed text file, using string representations of elements. - */ - def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { - this.map(x => (NullWritable.get(), new Text(x.toString))) - .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) - } - - /** - * Save this RDD as a SequenceFile of serialized objects. - */ - def saveAsObjectFile(path: String) { - this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) - .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) - .saveAsSequenceFile(path) - } - - /** - * Creates tuples of the elements in this RDD by applying `f`. - */ - def keyBy[K](f: T => K): RDD[(K, T)] = { - map(x => (f(x), x)) - } - - /** A private method for tests, to look at the contents of each partition */ - private[spark] def collectPartitions(): Array[Array[T]] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - } - - /** - * Mark this RDD for checkpointing. 
It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent - * RDDs will be removed. This function must be called before any job has been - * executed on this RDD. It is strongly recommended that this RDD is persisted in - * memory, otherwise saving it on a file will require recomputation. - */ - def checkpoint() { - if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") - } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() - } - } - - /** - * Return whether this RDD has been checkpointed or not - */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } - - /** - * Gets the name of the file to which this RDD was checkpointed - */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } - - // ======================================================================= - // Other internal methods and fields - // ======================================================================= - - private var storageLevel: StorageLevel = StorageLevel.NONE - - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite - - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - - /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ - def context = sc - - // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false - - /** - * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler - * after a job using this RDD has completed (therefore the RDD has been materialized and - * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. - */ - private[spark] def doCheckpoint() { - if (!doCheckpointCalled) { - doCheckpointCalled = true - if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() - } else { - dependencies.foreach(_.rdd.doCheckpoint()) - } - } - } - - /** - * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) - * created from the checkpoint file, and forget its old dependencies and partitions. - */ - private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { - clearDependencies() - partitions_ = null - deps = null // Forget the constructor argument for dependencies too - } - - /** - * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own cleaning - * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. - */ - protected def clearDependencies() { - dependencies_ = null - } - - /** A description of this RDD and its recursive dependencies for debugging. 
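// A sketch of the user-facing checkpoint flow above; the checkpoint directory is hypothetical
// and must be set on the SparkContext before checkpoint() is called.
sc.setCheckpointDir("hdfs:///tmp/checkpoints")
val data = sc.parallelize(1 to 1000).map(_ * 2).cache()  // persisting first avoids recomputation
data.checkpoint()                                         // only marks the RDD; nothing is written yet
data.count()                                              // the first job triggers doCheckpoint()
assert(data.isCheckpointed)
println(data.getCheckpointFile)                           // Some(.../rdd-<id>)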
*/ - def toDebugString: String = { - def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++ - rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) - } - debugString(this).mkString("\n") - } - - override def toString: String = "%s%s[%d] at %s".format( - Option(name).map(_ + " ").getOrElse(""), - getClass.getSimpleName, - id, - origin) - - def toJavaRDD() : JavaRDD[T] = { - new JavaRDD(this)(elementClassManifest) - } - -} diff --git a/core/src/main/scala/org/apache/spark/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/RDDCheckpointData.scala deleted file mode 100644 index 0334de6924..0000000000 --- a/core/src/main/scala/org/apache/spark/RDDCheckpointData.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration -import rdd.{CheckpointRDD, CoalescedRDD} -import scheduler.{ResultTask, ShuffleMapTask} - -/** - * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ -private[spark] object CheckpointState extends Enumeration { - type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value -} - -/** - * This class contains all the information related to RDD checkpointing. Each instance of this class - * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, - * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations - * of the checkpointed RDD. - */ -private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) - extends Logging with Serializable { - - import CheckpointState._ - - // The checkpoint state of the associated RDD. - var cpState = Initialized - - // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None - - // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None - - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } - - // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } - } - - // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } - } - - // Do the checkpointing of the RDD. 
Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return - RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { - cpState = CheckpointingInProgress - } else { - return - } - } - - // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(new Configuration()) - if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) - } - - // Save to file, and reload it as an RDD - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) - val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - - // Change the dependencies and partitions of the RDD - RDDCheckpointData.synchronized { - cpFile = Some(path.toString) - cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions - cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() - logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) - } - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } - } - - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } - } - - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } - } -} - -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} diff --git a/core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala deleted file mode 100644 index d58fb4e4bc..0000000000 --- a/core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import java.io.EOFException -import java.net.URL -import java.io.ObjectInputStream -import java.util.concurrent.atomic.AtomicLong -import java.util.HashSet -import java.util.Random -import java.util.Date - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Map -import scala.collection.mutable.HashMap - -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.TextOutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat -import org.apache.hadoop.mapred.OutputCommitter -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.Writable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.Text - -import org.apache.spark.SparkContext._ - -/** - * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. - * - * Users should import `spark.SparkContext._` at the top of their program to use these functions. - */ -class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( - self: RDD[(K, V)]) - extends Logging - with Serializable { - - private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { - val c = { - if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { - classManifest[T].erasure - } else { - // We get the type of the Writable class by looking at the apply method which converts - // from T to Writable. Since we have two apply methods we filter out the one which - // is not of the form "java.lang.Object apply(java.lang.Object)" - implicitly[T => Writable].getClass.getDeclaredMethods().filter( - m => m.getReturnType().toString != "class java.lang.Object" && - m.getName() == "apply")(0).getReturnType - - } - // TODO: use something like WritableConverter to avoid reflection - } - c.asInstanceOf[Class[_ <: Writable]] - } - - /** - * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key - * and value types. If the key or value are Writable, then we use their classes directly; - * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc, - * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported - * file system. 
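// A sketch of saveAsSequenceFile as documented above; the Int keys and String values are
// converted to IntWritable and Text through the implicit conversions in SparkContext, and
// the output path is hypothetical.
import org.apache.spark.SparkContext._

val pairs = sc.parallelize(Seq((1, "one"), (2, "two")))
pairs.saveAsSequenceFile("hdfs:///tmp/seqfile")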
- */ - def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { - def anyToWritable[U <% Writable](u: U): Writable = u - - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass) - - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) - val format = classOf[SequenceFileOutputFormat[Writable, Writable]] - val jobConf = new JobConf(self.context.hadoopConfiguration) - if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) - } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala deleted file mode 100644 index 4bfc837710..0000000000 --- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.lang.reflect.Field -import java.lang.reflect.Modifier -import java.lang.reflect.{Array => JArray} -import java.util.IdentityHashMap -import java.util.concurrent.ConcurrentHashMap -import java.util.Random - -import javax.management.MBeanServer -import java.lang.management.ManagementFactory - -import scala.collection.mutable.ArrayBuffer - -import it.unimi.dsi.fastutil.ints.IntOpenHashSet - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -private[spark] object SizeEstimator extends Logging { - - // Sizes of primitive types - private val BYTE_SIZE = 1 - private val BOOLEAN_SIZE = 1 - private val CHAR_SIZE = 2 - private val SHORT_SIZE = 2 - private val INT_SIZE = 4 - private val LONG_SIZE = 8 - private val FLOAT_SIZE = 4 - private val DOUBLE_SIZE = 8 - - // Alignment boundary for objects - // TODO: Is this arch dependent ? 
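Editor's note: to make the primitive sizes above and the alignment boundary that follows concrete, here is a small standalone sketch (not Spark code) mirroring the 8-byte alignment and the array-size formula used later in visitArray; the 12-byte object header and 4-byte references assume a 64-bit JVM with compressed oops, as in initialize() below.

    object LayoutSketch {
      val ALIGN_SIZE = 8
      def alignSize(size: Long): Long = {
        val rem = size % ALIGN_SIZE
        if (rem == 0) size else size + ALIGN_SIZE - rem
      }

      def main(args: Array[String]) {
        val objectSize = 12L   // 64-bit JVM with compressed oops
        val intSize = 4L
        // An Array[Int] of length 10: object header + length field, then the primitive payload.
        val header = alignSize(objectSize + intSize)   // 16
        val payload = alignSize(10 * intSize)          // 40
        println(header + payload)                      // 56 bytes estimated
      }
    }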
- private val ALIGN_SIZE = 8 - - // A cache of ClassInfo objects for each class - private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] - - // Object and pointer sizes are arch dependent - private var is64bit = false - - // Size of an object reference - // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops - private var isCompressedOops = false - private var pointerSize = 4 - - // Minimum size of a java.lang.Object - private var objectSize = 8 - - initialize() - - // Sets object size, pointer size based on architecture and CompressedOops settings - // from the JVM. - private def initialize() { - is64bit = System.getProperty("os.arch").contains("64") - isCompressedOops = getIsCompressedOops - - objectSize = if (!is64bit) 8 else { - if(!isCompressedOops) { - 16 - } else { - 12 - } - } - pointerSize = if (is64bit && !isCompressedOops) 8 else 4 - classInfos.clear() - classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) - } - - private def getIsCompressedOops : Boolean = { - if (System.getProperty("spark.test.useCompressedOops") != null) { - return System.getProperty("spark.test.useCompressedOops").toBoolean - } - - try { - val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" - val server = ManagementFactory.getPlatformMBeanServer() - - // NOTE: This should throw an exception in non-Sun JVMs - val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") - val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", - Class.forName("java.lang.String")) - - val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, hotSpotMBeanClass) - // TODO: We could use reflection on the VMOption returned ? - return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") - } catch { - case e: Exception => { - // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB - val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) - val guessInWords = if (guess) "yes" else "not" - logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) - return guess - } - } - } - - /** - * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an - * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects - * to visit. - */ - private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) { - val stack = new ArrayBuffer[AnyRef] - var size = 0L - - def enqueue(obj: AnyRef) { - if (obj != null && !visited.containsKey(obj)) { - visited.put(obj, null) - stack += obj - } - } - - def isFinished(): Boolean = stack.isEmpty - - def dequeue(): AnyRef = { - val elem = stack.last - stack.trimEnd(1) - return elem - } - } - - /** - * Cached information about each class. We remember two things: the "shell size" of the class - * (size of all non-static fields plus the java.lang.Object size), and any fields that are - * pointers to objects. 
- */ - private class ClassInfo( - val shellSize: Long, - val pointerFields: List[Field]) {} - - def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) - - private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { - val state = new SearchState(visited) - state.enqueue(obj) - while (!state.isFinished) { - visitSingleObject(state.dequeue(), state) - } - return state.size - } - - private def visitSingleObject(obj: AnyRef, state: SearchState) { - val cls = obj.getClass - if (cls.isArray) { - visitArray(obj, cls, state) - } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) { - // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses - // the size estimator since it references the whole REPL. Do nothing in this case. In - // general all ClassLoaders and Classes will be shared between objects anyway. - } else { - val classInfo = getClassInfo(cls) - state.size += classInfo.shellSize - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } - } - } - - // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 - private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING - - private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { - val length = JArray.getLength(array) - val elementClass = cls.getComponentType - - // Arrays have object header and length field which is an integer - var arrSize: Long = alignSize(objectSize + INT_SIZE) - - if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) - state.size += arrSize - } else { - arrSize += alignSize(length * pointerSize) - state.size += arrSize - - if (length <= ARRAY_SIZE_FOR_SAMPLING) { - for (i <- 0 until length) { - state.enqueue(JArray.get(array, i)) - } - } else { - // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 - val rand = new Random(42) - val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE) - for (i <- 0 until ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = JArray.get(array, index) - size += SizeEstimator.estimate(elem, state.visited) - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong - } - } - } - - private def primitiveSize(cls: Class[_]): Long = { - if (cls == classOf[Byte]) - BYTE_SIZE - else if (cls == classOf[Boolean]) - BOOLEAN_SIZE - else if (cls == classOf[Char]) - CHAR_SIZE - else if (cls == classOf[Short]) - SHORT_SIZE - else if (cls == classOf[Int]) - INT_SIZE - else if (cls == classOf[Long]) - LONG_SIZE - else if (cls == classOf[Float]) - FLOAT_SIZE - else if (cls == classOf[Double]) - DOUBLE_SIZE - else throw new IllegalArgumentException( - "Non-primitive class " + cls + " passed to primitiveSize()") - } - - /** - * Get or compute the ClassInfo for a given class. 
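Editor's note: the sampling branch in visitArray above extrapolates from ARRAY_SAMPLE_SIZE elements to the full array; a toy standalone illustration of the same arithmetic (all numbers are made up):

    object SamplingSketch {
      def main(args: Array[String]) {
        val length = 10000        // array length (illustrative)
        val sampleSize = 100      // ARRAY_SAMPLE_SIZE
        val sampledBytes = 4800.0 // total estimate over the sampled elements (illustrative)
        val contents = ((length / (sampleSize * 1.0)) * sampledBytes).toLong
        println(contents)         // 480000 bytes for the referenced objects, on top of the pointer array
      }
    }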
- */ - private def getClassInfo(cls: Class[_]): ClassInfo = { - // Check whether we've already cached a ClassInfo for this class - val info = classInfos.get(cls) - if (info != null) { - return info - } - - val parent = getClassInfo(cls.getSuperclass) - var shellSize = parent.shellSize - var pointerFields = parent.pointerFields - - for (field <- cls.getDeclaredFields) { - if (!Modifier.isStatic(field.getModifiers)) { - val fieldClass = field.getType - if (fieldClass.isPrimitive) { - shellSize += primitiveSize(fieldClass) - } else { - field.setAccessible(true) // Enable future get()'s on this field - shellSize += pointerSize - pointerFields = field :: pointerFields - } - } - } - - shellSize = alignSize(shellSize) - - // Create and cache a new ClassInfo - val newInfo = new ClassInfo(shellSize, pointerFields) - classInfos.put(cls, newInfo) - return newInfo - } - - private def alignSize(size: Long): Long = { - val rem = size % ALIGN_SIZE - return if (rem == 0) size else (size + ALIGN_SIZE - rem) - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1207b242bc..faf0c2362a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,17 +54,15 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.deploy.LocalSparkCluster import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD, - OrderedRDDFunctions} +import org.apache.spark.rdd._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler, Schedulable, SchedulingMode} import org.apache.spark.scheduler.local.LocalScheduler import org.apache.spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo, BlockManagerSource} +import org.apache.spark.storage.{StorageUtils, BlockManagerSource} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} -import scala.Some +import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap} import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.RDDInfo import org.apache.spark.storage.StorageStatus diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6e6fe5df6b..478e5a0aaf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -29,7 +29,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster} import org.apache.spark.network.ConnectionManager import org.apache.spark.serializer.{Serializer, SerializerManager} -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.api.python.PythonWorkerFactory @@ -155,10 +155,10 @@ object SparkEnv extends Logging { val serializerManager = new SerializerManager val serializer = serializerManager.setDefault( - System.getProperty("spark.serializer", "org.apache.spark.JavaSerializer")) + System.getProperty("spark.serializer", "org.apache.spark.serializer.JavaSerializer")) val closureSerializer = serializerManager.get( - 
System.getProperty("spark.closure.serializer", "org.apache.spark.JavaSerializer")) + System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")) def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { diff --git a/core/src/main/scala/org/apache/spark/Utils.scala b/core/src/main/scala/org/apache/spark/Utils.scala deleted file mode 100644 index 1e17deb010..0000000000 --- a/core/src/main/scala/org/apache/spark/Utils.scala +++ /dev/null @@ -1,780 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} -import java.util.{Locale, Random, UUID} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.regex.Pattern - -import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.collection.JavaConversions._ -import scala.io.Source - -import com.google.common.io.Files -import com.google.common.util.concurrent.ThreadFactoryBuilder - -import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} - -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.deploy.SparkHadoopUtil -import java.nio.ByteBuffer - - -/** - * Various utility methods used by Spark. 
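Editor's note: the SparkEnv hunk above changes the fully-qualified default serializer names, so any override of these properties must use the relocated classes as well. A brief hedged sketch, noted here before the removed Utils.scala listing continues (the Kryo serializer is assumed to move to the same org.apache.spark.serializer subpackage in this rename):

    // Set before the SparkContext is created so SparkEnv picks the values up.
    System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    System.setProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")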
- */ -private object Utils extends Logging { - - /** Serialize an object using Java serialization */ - def serialize[T](o: T): Array[Byte] = { - val bos = new ByteArrayOutputStream() - val oos = new ObjectOutputStream(bos) - oos.writeObject(o) - oos.close() - return bos.toByteArray - } - - /** Deserialize an object using Java serialization */ - def deserialize[T](bytes: Array[Byte]): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) - return ois.readObject.asInstanceOf[T] - } - - /** Deserialize an object using Java serialization and the given ClassLoader */ - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - return ois.readObject.asInstanceOf[T] - } - - /** Serialize via nested stream using specific serializer */ - def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { - val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) - }) - try { - f(osWrapper) - } finally { - osWrapper.close() - } - } - - /** Deserialize via nested stream using specific serializer */ - def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { - val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - - override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) - }) - try { - f(isWrapper) - } finally { - isWrapper.close() - } - } - - /** - * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. - */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { - if (bb.hasArray) { - out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) - } else { - val bbval = new Array[Byte](bb.remaining()) - bb.get(bbval) - out.write(bbval) - } - } - - def isAlpha(c: Char): Boolean = { - (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') - } - - /** Split a string into words at non-alphabetic characters */ - def splitWords(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var i = 0 - while (i < s.length) { - var j = i - while (j < s.length && isAlpha(s.charAt(j))) { - j += 1 - } - if (j > i) { - buf += s.substring(i, j) - } - i = j - while (i < s.length && !isAlpha(s.charAt(i))) { - i += 1 - } - } - return buf - } - - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. 
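Editor's note: the serialization and string helpers above move to org.apache.spark.util.Utils in this patch and remain private to Spark, so this usage sketch pretends to live inside the org.apache.spark package; the data is illustrative and the relocated object is assumed to keep private[spark] visibility.

    package org.apache.spark   // Utils is package-private, so this sketch sits inside the spark package

    import org.apache.spark.util.Utils

    object UtilsUsageSketch {
      def main(args: Array[String]) {
        val bytes = Utils.serialize(Map("a" -> 1, "b" -> 2))
        val back  = Utils.deserialize[Map[String, Int]](bytes)
        println(back("b"))                        // 2
        println(Utils.splitWords("foo, bar-baz")) // e.g. ArrayBuffer(foo, bar, baz)
      }
    }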
- def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.find { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - }.isDefined - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - /** Create a temporary directory inside the given parent directory */ - def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { - var attempts = 0 - val maxAttempts = 10 - var dir: File = null - while (dir == null) { - attempts += 1 - if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory (under " + root + ") after " + - maxAttempts + " attempts!") - } - try { - dir = new File(root, "spark-" + UUID.randomUUID.toString) - if (dir.exists() || !dir.mkdirs()) { - dir = null - } - } catch { case e: IOException => ; } - } - - registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) - dir - } - - /** Copy all data from an InputStream to an OutputStream */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false) - { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) - } - } - if (closeStreams) { - in.close() - out.close() - } - } - - /** - * Download a file requested by the executor. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. - * - * Throws SparkException if the target file already exists and has different contents than - * the requested file. - */ - def fetchFile(url: String, targetDir: File) { - val filename = url.split("/").last - val tempDir = getLocalDir - val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) - val targetFile = new File(targetDir, filename) - val uri = new URI(url) - uri.getScheme match { - case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - val in = new URL(url).openStream() - val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) - if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - tempFile.delete() - throw new SparkException( - "File " + targetFile + " exists and does not match contents of" + " " + url) - } else { - Files.move(tempFile, targetFile) - } - case "file" | null => - // In the case of a local file, copy the local file to the target directory. - // Note the difference between uri vs url. - val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) - if (targetFile.exists) { - // If the target file already exists, warn the user if - if (!Files.equal(sourceFile, targetFile)) { - throw new SparkException( - "File " + targetFile + " exists and does not match contents of" + " " + url) - } else { - // Do nothing if the file contents are the same, i.e. this file has been copied - // previously. - logInfo(sourceFile.getAbsolutePath + " has been previously copied to " - + targetFile.getAbsolutePath) - } - } else { - // The file does not exist in the target directory. Copy it there. 
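Editor's note: a small sketch of how the directory and stream helpers above are typically combined, again from inside the spark package since the object is package-private; the source file path is an assumption for illustration only.

    package org.apache.spark

    import java.io.{File, FileInputStream, FileOutputStream}
    import org.apache.spark.util.Utils

    object TempDirSketch {
      def main(args: Array[String]) {
        val dir  = Utils.createTempDir()                 // registered for deletion on JVM exit
        val dest = new File(dir, "copy.bin")
        val in   = new FileInputStream("/tmp/source.bin") // assumed to exist
        val out  = new FileOutputStream(dest)
        Utils.copyStream(in, out, closeStreams = true)    // closes both streams when done
        println(dest.length())
      }
    }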
- logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - Files.copy(sourceFile, targetFile) - } - case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val env = SparkEnv.get - val uri = new URI(url) - val conf = env.hadoop.newConfiguration() - val fs = FileSystem.get(uri, conf) - val in = fs.open(new Path(uri)) - val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) - if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - tempFile.delete() - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) - } else { - Files.move(tempFile, targetFile) - } - } - // Decompress the file if it's a .tar or .tar.gz - if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xzf", filename), targetDir) - } else if (filename.endsWith(".tar")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xf", filename), targetDir) - } - // Make the file executable - That's necessary for scripts - FileUtil.chmod(targetFile.getAbsolutePath, "a+x") - } - - /** - * Get a temporary directory using Spark's spark.local.dir property, if set. This will always - * return a single directory, even though the spark.local.dir property might be a list of - * multiple paths. - */ - def getLocalDir: String = { - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) - } - - /** - * Shuffle the elements of a collection into a random order, returning the - * result in a new collection. Unlike scala.util.Random.shuffle, this method - * uses a local random number generator, avoiding inter-thread contention. - */ - def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { - randomizeInPlace(seq.toArray) - } - - /** - * Shuffle the elements of an array into a random order, modifying the - * original array. Returns the original array. - */ - def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { - for (i <- (arr.length - 1) to 1 by -1) { - val j = rand.nextInt(i) - val tmp = arr(j) - arr(j) = arr(i) - arr(i) = tmp - } - arr - } - - /** - * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). - * Note, this is typically not used from within core spark. - */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) - - private def findLocalIpAddress(): String = { - val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") - if (defaultIpOverride != null) { - defaultIpOverride - } else { - val address = InetAddress.getLocalHost - if (address.isLoopbackAddress) { - // Address resolves to something like 127.0.1.1, which happens on Debian; try to find - // a better address using the local network interfaces - for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { - // We've found an address that looks reasonable! 
- logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") - logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress - } - } - logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + ", but we couldn't find any" + - " external IP address!") - logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - } - address.getHostAddress - } - } - - private var customHostname: Option[String] = None - - /** - * Allow setting a custom host name because when we run on Mesos we need to use the same - * hostname it reports to the master. - */ - def setCustomHostname(hostname: String) { - // DEBUG code - Utils.checkHost(hostname) - customHostname = Some(hostname) - } - - /** - * Get the local machine's hostname. - */ - def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) - } - - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName - } - - def localHostPort(): String = { - val retval = System.getProperty("spark.hostPort", null) - if (retval == null) { - logErrorWithStack("spark.hostPort not set but invoking localHostPort") - return localHostName() - } - - retval - } - - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) - } - - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) - } - - // Used by DEBUG code : remove when all testing done - def logErrorWithStack(msg: String) { - try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } - } - - // Typically, this will be of order of number of nodes in cluster - // If not, we should change it to LRUCache or something. - private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() - - def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - var cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached - } - - val indx: Int = hostPort.lastIndexOf(':') - // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... - // but then hadoop does not support ipv6 right now. - // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 - if (-1 == indx) { - val retval = (hostPort, 0) - hostPortParseResults.put(hostPort, retval) - return retval - } - - val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) - hostPortParseResults.putIfAbsent(hostPort, retval) - hostPortParseResults.get(hostPort) - } - - private[spark] val daemonThreadFactory: ThreadFactory = - new ThreadFactoryBuilder().setDaemon(true).build() - - /** - * Wrapper over newCachedThreadPool. - */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = - Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] - - /** - * Return the string to tell how long has passed in seconds. The passing parameter should be in - * millisecond. - */ - def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms" - } - - /** - * Wrapper over newFixedThreadPool. 
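Editor's note: the host:port parsing above is easiest to see by example; this is a standalone mirror of the removed logic (the real helper also memoizes results in a ConcurrentHashMap, omitted here).

    object HostPortSketch {
      def parseHostPort(hostPort: String): (String, Int) = {
        val idx = hostPort.lastIndexOf(':')
        if (idx == -1) (hostPort, 0)
        else (hostPort.substring(0, idx).trim(), hostPort.substring(idx + 1).trim().toInt)
      }

      def main(args: Array[String]) {
        println(parseHostPort("worker-7:7078"))  // (worker-7,7078)
        println(parseHostPort("worker-7"))       // (worker-7,0) -- no port present
      }
    }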
- */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = - Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] - - /** - * Delete a file or directory and its contents recursively. - */ - def deleteRecursively(file: File) { - if (file.isDirectory) { - for (child <- file.listFiles()) { - deleteRecursively(child) - } - } - if (!file.delete()) { - throw new IOException("Failed to delete: " + file) - } - } - - /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM - * environment variable. - */ - def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } - } - - /** - * Convert a quantity in bytes to a human-readable string such as "4.0 MB". - */ - def bytesToString(size: Long): String = { - val TB = 1L << 40 - val GB = 1L << 30 - val MB = 1L << 20 - val KB = 1L << 10 - - val (value, unit) = { - if (size >= 2*TB) { - (size.asInstanceOf[Double] / TB, "TB") - } else if (size >= 2*GB) { - (size.asInstanceOf[Double] / GB, "GB") - } else if (size >= 2*MB) { - (size.asInstanceOf[Double] / MB, "MB") - } else if (size >= 2*KB) { - (size.asInstanceOf[Double] / KB, "KB") - } else { - (size.asInstanceOf[Double], "B") - } - } - "%.1f %s".formatLocal(Locale.US, value, unit) - } - - /** - * Returns a human-readable string representing a duration such as "35ms" - */ - def msDurationToString(ms: Long): String = { - val second = 1000 - val minute = 60 * second - val hour = 60 * minute - - ms match { - case t if t < second => - "%d ms".format(t) - case t if t < minute => - "%.1f s".format(t.toFloat / second) - case t if t < hour => - "%.1f m".format(t.toFloat / minute) - case t => - "%.2f h".format(t.toFloat / hour) - } - } - - /** - * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". - */ - def megabytesToString(megabytes: Long): String = { - bytesToString(megabytes * 1024L * 1024L) - } - - /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) - } - } - - /** - * Execute a command in the current working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String]) { - execute(command, new File(".")) - } - - /** - * Execute a command and get its output, throwing an exception if it yields a code other than 0. 
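Editor's note: to make the memory- and duration-formatting helpers above concrete, a sketch of a few expected results; it assumes the helpers keep their names under the relocated org.apache.spark.util.Utils, as the updated imports elsewhere in this patch suggest.

    package org.apache.spark

    import org.apache.spark.util.Utils

    object FormatSketch {
      def main(args: Array[String]) {
        println(Utils.memoryStringToMb("300m"))         // 300
        println(Utils.memoryStringToMb("2g"))           // 2048
        println(Utils.bytesToString(5L * 1024 * 1024))  // 5.0 MB
        println(Utils.msDurationToString(1500))         // 1.5 s
      }
    }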
- */ - def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - val output = new StringBuffer - val stdoutThread = new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { - output.append(line) - } - } - } - stdoutThread.start() - val exitCode = process.waitFor() - stdoutThread.join() // Wait for it to finish reading output - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) - } - output.toString - } - - /** - * A regular expression to match classes of the "core" Spark API that we want to skip when - * finding the call site of a method. - */ - private val SPARK_CLASS_REGEX = """^spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r - - private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) - - /** - * When called inside a class in the spark package, returns the name of the user code class - * (outside the spark package) that called into Spark, as well as which Spark method they called. - * This is used, for example, to tell users where in their code each RDD got created. - */ - def getCallSiteInfo: CallSiteInfo = { - val trace = Thread.currentThread.getStackTrace().filter( el => - (!el.getMethodName.contains("getStackTrace"))) - - // Keep crawling up the stack trace until we find the first function not inside of the spark - // package. We track the last (shallowest) contiguous Spark method. This might be an RDD - // transformation, a SparkContext function (such as parallelize), or anything else that leads - // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. - var lastSparkMethod = "" - var firstUserFile = "" - var firstUserLine = 0 - var finished = false - var firstUserClass = "" - - for (el <- trace) { - if (!finished) { - if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) { - lastSparkMethod = if (el.getMethodName == "") { - // Spark method is a constructor; get its class name - el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) - } else { - el.getMethodName - } - } - else { - firstUserLine = el.getLineNumber - firstUserFile = el.getFileName - firstUserClass = el.getClassName - finished = true - } - } - } - new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) - } - - def formatSparkCallSite = { - val callSiteInfo = getCallSiteInfo - "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, - callSiteInfo.firstUserLine) - } - - /** Return a string containing part of a file from byte 'start' to 'end'. 
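Editor's note: a quick sketch of the process helpers above, again assuming the relocated package-private Utils; the command is illustrative.

    package org.apache.spark

    import java.io.File
    import org.apache.spark.util.Utils

    object ExecSketch {
      def main(args: Array[String]) {
        // Throws SparkException on a non-zero exit code; stdout lines are concatenated.
        val out = Utils.executeAndGetOutput(Seq("echo", "hello"), new File("."))
        println(out)   // hello
      }
    }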
*/ - def offsetBytes(path: String, start: Long, end: Long): String = { - val file = new File(path) - val length = file.length() - val effectiveEnd = math.min(length, end) - val effectiveStart = math.max(0, start) - val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) - val stream = new FileInputStream(file) - - stream.skip(effectiveStart) - stream.read(buff) - stream.close() - Source.fromBytes(buff).mkString - } - - /** - * Clone an object using a Spark serializer. - */ - def clone[T](value: T, serializer: SerializerInstance): T = { - serializer.deserialize[T](serializer.serialize(value)) - } - - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - return false - } - - def isSpace(c: Char): Boolean = { - " \t\r\n".indexOf(c) != -1 - } - - /** - * Split a string of potentially quoted arguments from the command line the way that a shell - * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', - * then it would be parsed as three arguments: 'a', 'b c' and 'd'. - */ - def splitCommandString(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var inWord = false - var inSingleQuote = false - var inDoubleQuote = false - var curWord = new StringBuilder - def endWord() { - buf += curWord.toString - curWord.clear() - } - var i = 0 - while (i < s.length) { - var nextChar = s.charAt(i) - if (inDoubleQuote) { - if (nextChar == '"') { - inDoubleQuote = false - } else if (nextChar == '\\') { - if (i < s.length - 1) { - // Append the next character directly, because only " and \ may be escaped in - // double quotes after the shell's own expansion - curWord.append(s.charAt(i + 1)) - i += 1 - } - } else { - curWord.append(nextChar) - } - } else if (inSingleQuote) { - if (nextChar == '\'') { - inSingleQuote = false - } else { - curWord.append(nextChar) - } - // Backslashes are not treated specially in single quotes - } else if (nextChar == '"') { - inWord = true - inDoubleQuote = true - } else if (nextChar == '\'') { - inWord = true - inSingleQuote = true - } else if (!isSpace(nextChar)) { - curWord.append(nextChar) - inWord = true - } else if (inWord && isSpace(nextChar)) { - endWord() - inWord = false - } - i += 1 - } - if (inWord || inDoubleQuote || inSingleQuote) { - endWord() - } - return buf - } - - /* Calculates 'x' modulo 'mod', takes to consideration sign of x, - * i.e. if 'x' is negative, than 'x' % 'mod' is negative too - * so function return (x % mod) + mod in that case. 
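Editor's note: the quoting and modulo helpers above are easiest to see by example; the expected values below follow directly from the removed code, with the same package-private access assumption as the earlier sketches.

    package org.apache.spark

    import org.apache.spark.util.Utils

    object ParsingSketch {
      def main(args: Array[String]) {
        // Shell-style splitting: 'a "b c" d' becomes three arguments.
        println(Utils.splitCommandString("""a "b c" d"""))  // e.g. ArrayBuffer(a, b c, d)
        // Sign-aware modulo: the result always lands in [0, mod).
        println(Utils.nonNegativeMod(-7, 3))                // 2
        println(Utils.nonNegativeMod(7, 3))                 // 1
      }
    }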
- */ - def nonNegativeMod(x: Int, mod: Int): Int = { - val rawMod = x % mod - rawMod + (if (rawMod < 0) mod else 0) - } -} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index cb25ff728e..5fd1fab580 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.util.StatCounter diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 09da35aee6..a6518abf45 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.HashPartitioner import org.apache.spark.Partitioner import org.apache.spark.Partitioner._ -import org.apache.spark.RDD import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.api.java.function.{Function2 => JFunction2} import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.PartialResult +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.OrderedRDDFunctions import org.apache.spark.storage.StorageLevel diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 68cfcf5999..eec58abdd6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.api.java import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 1ad8514980..7e6e691f11 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,13 +21,15 @@ import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ +import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.spark.{SparkContext, Partition, RDD, TaskContext} + +import org.apache.spark.{SparkContext, Partition, TaskContext} +import org.apache.spark.rdd.RDD import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import org.apache.spark.partial.{PartialResult, BoundedDouble} import org.apache.spark.storage.StorageLevel -import com.google.common.base.Optional trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 618a7b3bf7..8869e072bf 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ 
b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import com.google.common.base.Optional -import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} +import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, SparkContext} import org.apache.spark.SparkContext.IntAccumulatorParam import org.apache.spark.SparkContext.DoubleAccumulatorParam import org.apache.spark.broadcast.Broadcast - -import com.google.common.base.Optional +import org.apache.spark.rdd.RDD /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala index eea63d5a4e..b090c6edf3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala @@ -18,8 +18,8 @@ package org.apache.spark.api.python import org.apache.spark.Partitioner -import org.apache.spark.Utils import java.util.Arrays +import org.apache.spark.util.Utils /** * A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 621f0fe8ee..ccd3833964 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -26,7 +26,9 @@ import scala.collection.JavaConversions._ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PipedRDD +import org.apache.spark.util.Utils private[spark] class PythonRDD[T: ClassManifest]( diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala index 99e86237fc..93e7815ab5 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala @@ -27,6 +27,7 @@ import scala.math import org.apache.spark._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 7a52ff0769..9db26ae6de 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -23,10 +23,10 @@ import java.net.URL import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream -import org.apache.spark.{HttpServer, Logging, SparkEnv, Utils} +import org.apache.spark.{HttpServer, Logging, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.StorageLevel -import 
org.apache.spark.util.{MetadataCleaner, TimeStampedHashSet} +import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala index 10b910df87..21ec94659e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala @@ -24,6 +24,7 @@ import java.util.Random import scala.collection.mutable.Map import org.apache.spark._ +import org.apache.spark.util.Utils private object MultiTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala index b5a4ccc0ee..80c97ca073 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala @@ -26,6 +26,7 @@ import scala.math import org.apache.spark._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 4dc6ada2d1..c31619db27 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy import scala.collection.immutable.List -import org.apache.spark.Utils import org.apache.spark.deploy.ExecutorState.ExecutorState import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo} import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index af5a4110b0..78e3747ad8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -21,8 +21,8 @@ import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master -import org.apache.spark.util.AkkaUtils -import org.apache.spark.{Logging, Utils} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.{Logging} import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 0322029fbd..d5e9a0e095 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.client -import org.apache.spark.util.AkkaUtils -import org.apache.spark.{Logging, Utils} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.{Logging} import org.apache.spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala 
b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 869b2b2646..7cf0a7754f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -27,12 +27,12 @@ import akka.actor.Terminated import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} import akka.util.duration._ -import org.apache.spark.{Logging, SparkException, Utils} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{Utils, AkkaUtils} private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c86cca278d..9d89b455fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy.master -import org.apache.spark.util.IntParam -import org.apache.spark.Utils +import org.apache.spark.util.{Utils, IntParam} /** * Command-line parser for the master. diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 285e07a823..6219f11f2a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import akka.actor.ActorRef import scala.collection.mutable -import org.apache.spark.Utils +import org.apache.spark.util.Utils private[spark] class WorkerInfo( val id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 6435c7f917..f4e574d15d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -31,7 +31,7 @@ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMaste import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.ExecutorInfo import org.apache.spark.ui.UIUtils -import org.apache.spark.Utils +import org.apache.spark.util.Utils private[spark] class ApplicationPage(parent: MasterWebUI) { val master = parent.masterActorRef diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala index 58d3863009..d7a57229b0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -27,12 +27,12 @@ import akka.util.duration._ import net.liftweb.json.JsonAST.JValue -import org.apache.spark.Utils import org.apache.spark.deploy.DeployWebUI import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo} import 
org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils private[spark] class IndexPage(parent: MasterWebUI) { val master = parent.masterActorRef diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 47b1e521f5..f4df729e87 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import org.eclipse.jetty.server.{Handler, Server} -import org.apache.spark.{Logging, Utils} +import org.apache.spark.{Logging} import org.apache.spark.deploy.master.Master import org.apache.spark.ui.JettyUtils import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.util.Utils /** * Web UI server for the standalone master. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 01ce4a6dea..e3dc30eefc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,9 +25,10 @@ import akka.actor.ActorRef import com.google.common.base.Charsets import com.google.common.io.Files -import org.apache.spark.{Utils, Logging} +import org.apache.spark.{Logging} import org.apache.spark.deploy.{ExecutorState, ApplicationDescription} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.Utils /** * Manages the execution of one executor process. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 86e8e7543b..09530beb3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -27,13 +27,13 @@ import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} import akka.util.duration._ -import org.apache.spark.{Logging, Utils} +import org.apache.spark.{Logging} import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{Utils, AkkaUtils} private[spark] class Worker( diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 6d91223413..0ae89a864f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -17,9 +17,7 @@ package org.apache.spark.deploy.worker -import org.apache.spark.util.IntParam -import org.apache.spark.util.MemoryParam -import org.apache.spark.Utils +import org.apache.spark.util.{Utils, IntParam, MemoryParam} import java.lang.management.ManagementFactory /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala index 6192c2324b..d2d3617498 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -27,11 +27,11 @@ import akka.util.duration._ import net.liftweb.json.JsonAST.JValue -import org.apache.spark.Utils import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} import org.apache.spark.deploy.worker.ExecutorRunner import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils private[spark] class IndexPage(parent: WorkerWebUI) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index bb8165ac09..95d6007f3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -26,10 +26,11 @@ import javax.servlet.http.HttpServletRequest import org.eclipse.jetty.server.{Handler, Server} import org.apache.spark.deploy.worker.Worker -import org.apache.spark.{Utils, Logging} +import org.apache.spark.{Logging} import org.apache.spark.ui.JettyUtils import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils /** * Web UI server for the standalone worker. diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5446a3fca9..d365804994 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.scheduler._ import org.apache.spark._ +import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 410a94df6b..da62091980 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -22,8 +22,9 @@ import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNa import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} import org.apache.spark.TaskState.TaskState import com.google.protobuf.ByteString -import org.apache.spark.{Utils, Logging} +import org.apache.spark.{Logging} import org.apache.spark.TaskState +import org.apache.spark.util.Utils private[spark] class MesosExecutorBackend extends MesosExecutor diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala index 65801f75b7..7839023868 100644 --- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala @@ -22,10 +22,10 @@ import java.nio.ByteBuffer import akka.actor.{ActorRef, Actor, Props, Terminated} import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import org.apache.spark.{Logging, Utils, SparkEnv} +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{Utils, AkkaUtils} private[spark] class StandaloneExecutorBackend( diff --git 
a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 9e2233c07b..e15a839c4e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration import akka.util.duration._ +import org.apache.spark.util.Utils private[spark] class ConnectionManager(port: Int) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala index 0839c011b8..50dd9bc2d1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala @@ -19,7 +19,7 @@ package org.apache.spark.network import java.net.InetSocketAddress -import org.apache.spark.Utils +import org.apache.spark.util.Utils private[spark] case class ConnectionManagerId(host: String, port: Int) { diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 1126480689..c0ec527339 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -1,3 +1,5 @@ +import org.apache.spark.rdd.{SequenceFileRDDFunctions, DoubleRDDFunctions, PairRDDFunctions} + /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -16,16 +18,17 @@ */ /** - * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to Spark, while - * [[org.apache.spark.RDD]] is the data type representing a distributed collection, and provides most - * parallel operations. + * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to + * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, + * and provides most parallel operations. * - * In addition, [[org.apache.spark.PairRDDFunctions]] contains operations available only on RDDs of key-value - * pairs, such as `groupByKey` and `join`; [[org.apache.spark.DoubleRDDFunctions]] contains operations - * available only on RDDs of Doubles; and [[org.apache.spark.SequenceFileRDDFunctions]] contains operations - * available on RDDs that can be saved as SequenceFiles. These operations are automatically - * available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit conversions when - * you `import org.apache.spark.SparkContext._`. + * In addition, [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs + * of key-value pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] + * contains operations available only on RDDs of Doubles; and + * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can + * be saved as SequenceFiles. These operations are automatically available on any RDD of the right + * type (e.g. RDD[(Int, Int)] through implicit conversions when you + * `import org.apache.spark.SparkContext._`. 
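As a quick illustration of the implicit conversions this package doc describes, here is a small self-contained sketch (not part of the patch); the application name and sample data are made up.

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._   // the implicit conversions described above

object ImplicitsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "ImplicitsExample")
    // PairRDDFunctions become available on an RDD[(String, Int)] via the implicits:
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    println(pairs.groupByKey().collect().toSeq)
    // DoubleRDDFunctions become available on an RDD[Double] the same way:
    val doubles = sc.parallelize(Seq(1.0, 2.0, 3.0))
    println(doubles.mean())
    sc.stop()
  }
}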
*/ package object spark { // For package docs only diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala index c5d51bee50..d71069444a 100644 --- a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala @@ -19,6 +19,7 @@ package org.apache.spark.partial import org.apache.spark._ import org.apache.spark.scheduler.JobListener +import org.apache.spark.rdd.RDD /** * A JobListener for an approximate single-result action, such as count() or non-parallel reduce(). diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 4bb01efa86..bca6956a18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} +import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} import org.apache.spark.storage.BlockManager private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index dcc35e8d0e..0187256a8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,7 +23,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Partition, Partitioner, RDD, SparkEnv, TaskContext} +import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala new file mode 100644 index 0000000000..a4bec41752 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.partial.BoundedDouble +import org.apache.spark.partial.MeanEvaluator +import org.apache.spark.partial.PartialResult +import org.apache.spark.partial.SumEvaluator +import org.apache.spark.util.StatCounter +import org.apache.spark.{TaskContext, Logging} + +/** + * Extra functions available on RDDs of Doubles through an implicit conversion. 
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { + /** Add up the elements in this RDD. */ + def sum(): Double = { + self.reduce(_ + _) + } + + /** + * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count + * of the RDD's elements in one operation. + */ + def stats(): StatCounter = { + self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) + } + + /** Compute the mean of this RDD's elements. */ + def mean(): Double = stats().mean + + /** Compute the variance of this RDD's elements. */ + def variance(): Double = stats().variance + + /** Compute the standard deviation of this RDD's elements. */ + def stdev(): Double = stats().stdev + + /** + * Compute the sample standard deviation of this RDD's elements (which corrects for bias in + * estimating the standard deviation by dividing by N-1 instead of N). + */ + def sampleStdev(): Double = stats().sampleStdev + + /** + * Compute the sample variance of this RDD's elements (which corrects for bias in + * estimating the variance by dividing by N-1 instead of N). + */ + def sampleVariance(): Double = stats().sampleVariance + + /** (Experimental) Approximate operation to return the mean within a timeout. */ + def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) + val evaluator = new MeanEvaluator(self.partitions.size, confidence) + self.context.runApproximateJob(self, processPartition, evaluator, timeout) + } + + /** (Experimental) Approximate operation to return the sum within a timeout. 
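A brief usage sketch of the DoubleRDDFunctions defined above, assuming an existing SparkContext named sc (for example in spark-shell) and the SparkContext._ implicits in scope; the numbers are illustrative.

val xs = sc.parallelize(1 to 100).map(_.toDouble)
println(xs.sum())            // 5050.0
println(xs.mean())           // 50.5
println(xs.stdev())          // population standard deviation (divides by N)
val s = xs.stats()           // count, mean, variance, stdev in a single pass
println(s.count + " values, mean " + s.mean)
// Approximate mean with a 10 second timeout and 95% confidence:
println(xs.meanApprox(10000, 0.95).getFinalValue())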
*/ + def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) + val evaluator = new SumEvaluator(self.partitions.size, confidence) + self.context.runApproximateJob(self, processPartition, evaluator, timeout) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala index 24ce4abbc4..c8900d1a93 100644 --- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} +import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} /** diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala index 4df8ceb58b..5312dc0b59 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{OneToOneDependency, RDD, Partition, TaskContext} +import org.apache.spark.{OneToOneDependency, Partition, TaskContext} private[spark] class FilteredRDD[T: ClassManifest]( prev: RDD[T], diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala index 2bf7653af1..cbdf6d84c0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, Partition, TaskContext} +import org.apache.spark.{Partition, TaskContext} private[spark] diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala index e544720b05..82000bac09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{TaskContext, Partition, RDD} +import org.apache.spark.{TaskContext, Partition} private[spark] diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala index 2ce94199f2..829545d7b0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, Partition, TaskContext} +import org.apache.spark.{Partition, TaskContext} private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 08e6154bb9..2cb6734e41 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -18,21 +18,15 @@ package org.apache.spark.rdd import java.io.EOFException -import java.util.NoSuchElementException -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.FileInputFormat import org.apache.hadoop.mapred.InputFormat import 
org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import org.apache.spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, TaskContext} import org.apache.spark.util.NextIterator import org.apache.hadoop.conf.{Configuration, Configurable} diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 3db460b3ce..aca0146884 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.sql.{Connection, ResultSet} -import org.apache.spark.{Logging, Partition, RDD, SparkContext, TaskContext} +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.util.NextIterator private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 13009d3e17..203179c4ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, Partition, TaskContext} +import org.apache.spark.{Partition, TaskContext} private[spark] diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala index 1683050b86..3ed8339010 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, Partition, TaskContext} +import org.apache.spark.{Partition, TaskContext} /** diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala index 26d4806edb..e8be1c4816 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{RDD, Partition, TaskContext} +import org.apache.spark.{Partition, TaskContext} private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala index a405e9acdd..d33c1af581 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.rdd -import org.apache.spark.{TaskContext, Partition, RDD} +import org.apache.spark.{TaskContext, Partition} private[spark] class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 114b504486..7b3a89f7e0 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext} private[spark] diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 4c3df0eaf4..697be8b997 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -17,12 +17,13 @@ package org.apache.spark.rdd -import org.apache.spark.{RangePartitioner, Logging, RDD} +import org.apache.spark.{RangePartitioner, Logging} /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through - * an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these - * functions. They will work with any key type that has a `scala.math.Ordered` implementation. + * an implicit conversion. Import `org.apache.spark.SparkContext._` at the top of your program to + * use these functions. They will work with any key type that has a `scala.math.Ordered` + * implementation. */ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest, diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala new file mode 100644 index 0000000000..a47c512275 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -0,0 +1,702 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.nio.ByteBuffer +import java.util.Date +import java.text.SimpleDateFormat +import java.util.{HashMap => JHashMap} + +import scala.collection.{mutable, Map} +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.mapred.FileOutputFormat +import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} +import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter} + +import org.apache.spark._ +import org.apache.spark.SparkContext._ +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.Aggregator +import org.apache.spark.Partitioner +import org.apache.spark.Partitioner.defaultPartitioner + +/** + * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. + * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) + extends Logging + with SparkHadoopMapReduceUtil + with Serializable { + + /** + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C + * Note that V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD, and whether to perform + * map-side aggregation (if a mapper can produce multiple items with the same key). + */ + def combineByKey[C](createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializerClass: String = null): RDD[(K, C)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } + val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + if (self.partitioner == Some(partitioner)) { + self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + } else if (mapSideCombine) { + val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) + .setSerializer(serializerClass) + partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) + } else { + // Don't apply map-side combiner. + // A sanity check to make sure mergeCombiners is not defined. 
+ assert(mergeCombiners == null) + val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) + values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + } + } + + /** + * Simplified version of combineByKey that hash-partitions the output RDD. + */ + def combineByKey[C](createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int): RDD[(K, C)] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + // When deserializing, use a lazy val to create just one instance of the serializer per task + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + + combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { + foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { + foldByKey(zeroValue, defaultPartitioner(self))(func) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ + def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { + combineByKey[V]((v: V) => v, func, func, partitioner) + } + + /** + * Merge the values for each key using an associative reduce function, but return the results + * immediately to the master as a Map. This will also perform the merging locally on each mapper + * before sending results to a reducer, similarly to a "combiner" in MapReduce. 
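To make the aggregation helpers above concrete, here is a hedged sketch of combineByKey, foldByKey and reduceByKey usage; it assumes an existing SparkContext sc with the SparkContext._ implicits in scope, and the per-key average is just an illustrative choice.

val scores = sc.parallelize(Seq(("alice", 80), ("bob", 95), ("alice", 90)))

// combineByKey: build a (sum, count) pair per key, then derive the average.
val sumCount = scores.combineByKey(
  (v: Int) => (v, 1),                                             // createCombiner
  (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1),          // mergeValue
  (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2))   // mergeCombiners
val averages = sumCount.mapValues { case (sum, count) => sum.toDouble / count }

// foldByKey: fold each key's values into the zero value described above.
val totals = scores.foldByKey(0)(_ + _)

// reduceByKey with an explicit number of output partitions.
val maxPerKey = scores.reduceByKey(_ max _, 4)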
+ */ + def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + + if (getKeyClass().isArray) { + throw new SparkException("reduceByKeyLocally() does not support array keys") + } + + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { + val map = new JHashMap[K, V] + iter.foreach { case (k, v) => + val old = map.get(k) + map.put(k, if (old == null) v else func(old, v)) + } + Iterator(map) + } + + def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { + m2.foreach { case (k, v) => + val old = m1.get(k) + m1.put(k, if (old == null) v else func(old, v)) + } + m1 + } + + self.mapPartitions(reducePartition).reduce(mergeMaps) + } + + /** Alias for reduceByKeyLocally */ + def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) + + /** Count the number of elements for each key, and return the result to the master as a Map. */ + def countByKey(): Map[K, Long] = self.map(_._1).countByValue() + + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. + */ + def countByKeyApprox(timeout: Long, confidence: Double = 0.95) + : PartialResult[Map[K, BoundedDouble]] = { + self.map(_._1).countByValueApprox(timeout, confidence) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) + } + + /** + * Group the values for each key in the RDD into a single sequence. Allows controlling the + * partitioning of the resulting key-value pair RDD by passing a Partitioner. + */ + def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + // groupByKey shouldn't use map side combine because map side combine does not + // reduce the amount of data shuffled and requires all map side data be inserted + // into a hash table, leading to more objects in the old gen. + def createCombiner(v: V) = ArrayBuffer(v) + def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + val bufs = combineByKey[ArrayBuffer[V]]( + createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) + bufs.asInstanceOf[RDD[(K, Seq[V])]] + } + + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with into `numPartitions` partitions. + */ + def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { + groupByKey(new HashPartitioner(numPartitions)) + } + + /** + * Return a copy of the RDD partitioned using the specified partitioner. + */ + def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { + if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + new ShuffledRDD[K, V, (K, V)](self, partitioner) + } + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. 
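A short sketch of the driver-side and grouping helpers above (reduceByKeyLocally, countByKey, groupByKey, partitionBy), again assuming an existing sc and the pair-RDD implicits; the data and partition counts are arbitrary.

import org.apache.spark.HashPartitioner

val clicks = sc.parallelize(Seq(("home", 1), ("search", 1), ("home", 1)))

val totalsOnDriver = clicks.reduceByKeyLocally(_ + _)    // merged into a Map on the driver
val counts         = clicks.countByKey()                 // Map[String, Long] on the driver

val grouped       = clicks.groupByKey(4)                       // RDD[(String, Seq[Int])]
val repartitioned = clicks.partitionBy(new HashPartitioner(8)) // explicit partitioner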
+ */ + def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { + this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => + for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to + * partition the output RDD. + */ + def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { + this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => + if (ws.isEmpty) { + vs.iterator.map(v => (v, None)) + } else { + for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w)) + } + } + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to + * partition the output RDD. + */ + def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) + : RDD[(K, (Option[V], W))] = { + this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => + if (vs.isEmpty) { + ws.iterator.map(w => (None, w)) + } else { + for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w) + } + } + } + + /** + * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. + */ + def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) + : RDD[(K, C)] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. + */ + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + reduceByKey(defaultPartitioner(self), func) + } + + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with the existing partitioner/parallelism level. + */ + def groupByKey(): RDD[(K, Seq[V])] = { + groupByKey(defaultPartitioner(self)) + } + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ + def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { + join(other, defaultPartitioner(self, other)) + } + + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ + def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { + join(other, new HashPartitioner(numPartitions)) + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. 
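The join family above can be exercised as in the following sketch; sc, the sample data and the partition counts are assumptions, not part of the patch.

import org.apache.spark.HashPartitioner

val users  = sc.parallelize(Seq((1, "alice"), (2, "bob")))
val orders = sc.parallelize(Seq((1, "book"), (1, "pen"), (3, "lamp")))

val inner = users.join(orders)            // RDD[(Int, (String, String))]
val left  = users.leftOuterJoin(orders)   // RDD[(Int, (String, Option[String]))]
val right = users.rightOuterJoin(orders)  // RDD[(Int, (Option[String], String))]

// The same joins also accept an explicit Partitioner or partition count:
val joined = users.join(orders, new HashPartitioner(16))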
Hash-partitions the output + * using the existing partitioner/parallelism level. + */ + def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { + leftOuterJoin(other, defaultPartitioner(self, other)) + } + + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * into `numPartitions` partitions. + */ + def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { + leftOuterJoin(other, new HashPartitioner(numPartitions)) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD using the existing partitioner/parallelism level. + */ + def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { + rightOuterJoin(other, defaultPartitioner(self, other)) + } + + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD into the given number of partitions. + */ + def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { + rightOuterJoin(other, new HashPartitioner(numPartitions)) + } + + /** + * Return the key-value pairs in this RDD to the master as a Map. + */ + def collectAsMap(): Map[K, V] = { + val data = self.toArray() + val map = new mutable.HashMap[K, V] + map.sizeHint(data.length) + data.foreach { case (k, v) => map.put(k, v) } + map + } + + /** + * Pass each value in the key-value pair RDD through a map function without changing the keys; + * this also retains the original RDD's partitioning. + */ + def mapValues[U](f: V => U): RDD[(K, U)] = { + val cleanF = self.context.clean(f) + new MappedValuesRDD(self, cleanF) + } + + /** + * Pass each value in the key-value pair RDD through a flatMap function without changing the + * keys; this also retains the original RDD's partitioning. + */ + def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { + val cleanF = self.context.clean(f) + new FlatMappedValuesRDD(self, cleanF) + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) + prfs.mapValues { case Seq(vs, ws) => + (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + } + } + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. 
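A small sketch of the value transformations and cogroup defined above, under the same assumption of an existing sc with SparkContext._ imported; the inventory/price data is made up.

val inventory = sc.parallelize(Seq(("pen", 10), ("book", 3)))
val prices    = sc.parallelize(Seq(("pen", 1.5), ("pen", 1.2), ("lamp", 20.0)))

val doubled  = inventory.mapValues(_ * 2)              // keys and partitioning preserved
val exploded = inventory.flatMapValues(n => 1 to n)    // one record per unit in stock
val asMap    = inventory.collectAsMap()                // driver-side Map

// cogroup collects the values from both RDDs for each key:
val grouped = inventory.cogroup(prices)   // RDD[(String, (Seq[Int], Seq[Double]))]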
+ */ + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) + prfs.mapValues { case Seq(vs, w1s, w2s) => + (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) + } + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner(self, other)) + } + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + cogroup(other1, other2, defaultPartitioner(self, other1, other2)) + } + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ + def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, new HashPartitioner(numPartitions)) + } + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + cogroup(other1, other2, new HashPartitioner(numPartitions)) + } + + /** Alias for cogroup. */ + def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner(self, other)) + } + + /** Alias for cogroup. */ + def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) + : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + cogroup(other1, other2, defaultPartitioner(self, other1, other2)) + } + + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] = + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + + /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = + subtractByKey(other, new HashPartitioner(numPartitions)) + + /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = + new SubtractedRDD[K, V, W](self, other, p) + + /** + * Return the list of values in the RDD for key `key`. This operation is done efficiently if the + * RDD has a known partitioner by only searching the partition that the key maps to. 
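subtractByKey and lookup can be used as sketched below; the data is illustrative and an existing sc with the pair-RDD implicits is assumed.

val all     = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val skipped = sc.parallelize(Seq(("b", true)))

val remaining = all.subtractByKey(skipped)   // keeps ("a", 1) and ("c", 3)
val aValues   = all.lookup("a")              // Seq(1), returned to the driver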
+ */ + def lookup(key: K): Seq[V] = { + self.partitioner match { + case Some(p) => + val index = p.getPartition(key) + def process(it: Iterator[(K, V)]): Seq[V] = { + val buf = new ArrayBuffer[V] + for ((k, v) <- it if k == key) { + buf += v + } + buf + } + val res = self.context.runJob(self, process _, Array(index), false) + res(0) + case None => + self.filter(_._1 == key).map(_._2).collect() + } + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. + */ + def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress the result with the + * supplied codec. + */ + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. + */ + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { + saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. + */ + def saveAsNewAPIHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = self.context.hadoopConfiguration) { + val job = new NewAPIHadoopJob(conf) + job.setOutputKeyClass(keyClass) + job.setOutputValueClass(valueClass) + val wrappedConf = new SerializableWritable(job.getConfiguration) + NewFileOutputFormat.setOutputPath(job, new Path(path)) + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + val stageId = self.id + def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + /* "reduce task" */ + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = outputFormatClass.newInstance + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + while (iter.hasNext) { + val (k, v) = iter.next() + writer.write(k, v) + } + writer.close(hadoopContext) + committer.commitTask(hadoopContext) + return 1 + } + val jobFormat = outputFormatClass.newInstance + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. 
+ */ + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + jobCommitter.setupJob(jobTaskContext) + val count = self.context.runJob(self, writeShard _).sum + jobCommitter.commitJob(jobTaskContext) + jobCommitter.cleanupJob(jobTaskContext) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress with the supplied codec. + */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + codec: Class[_ <: CompressionCodec]) { + saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, + new JobConf(self.context.hadoopConfiguration), Some(codec)) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. + */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf = new JobConf(self.context.hadoopConfiguration), + codec: Option[Class[_ <: CompressionCodec]] = None) { + conf.setOutputKeyClass(keyClass) + conf.setOutputValueClass(valueClass) + // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug + conf.set("mapred.output.format.class", outputFormatClass.getName) + for (c <- codec) { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + } + conf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) + saveAsHadoopDataset(conf) + } + + /** + * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for + * that storage system. The JobConf should set an OutputFormat and any output paths required + * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop + * MapReduce job. + */ + def saveAsHadoopDataset(conf: JobConf) { + val outputFormatClass = conf.getOutputFormat + val keyClass = conf.getOutputKeyClass + val valueClass = conf.getOutputValueClass + if (outputFormatClass == null) { + throw new SparkException("Output format class not set") + } + if (keyClass == null) { + throw new SparkException("Output key class not set") + } + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + + logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") + + val writer = new SparkHadoopWriter(conf) + writer.preSetup() + + def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. 
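For the Hadoop output helpers in this file, a hedged example follows; TextOutputFormat, the Writable key/value types and the HDFS path are illustrative choices and not prescribed by the patch. An existing sc and the SparkContext._ implicits are assumed.

import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.spark.SparkContext._

val counts = sc.parallelize(Seq(("spark", 10), ("hadoop", 7)))
  .map { case (k, v) => (new Text(k), new IntWritable(v)) }

// Old Hadoop API; the OutputFormat class is selected via the type parameter:
counts.saveAsHadoopFile[TextOutputFormat[Text, IntWritable]]("hdfs:///tmp/word-counts")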
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt + + writer.setup(context.stageId, context.splitId, attemptNumber) + writer.open() + + var count = 0 + while(iter.hasNext) { + val record = iter.next() + count += 1 + writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) + } + + writer.close() + writer.commit() + } + + self.context.runJob(self, writeToFile _) + writer.commitJob() + writer.cleanup() + } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys: RDD[K] = self.map(_._1) + + /** + * Return an RDD with the values of each tuple. + */ + def values: RDD[V] = self.map(_._2) + + private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure + + private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure +} + +private[spark] object Manifests { + val seqSeqManifest = classManifest[Seq[Seq[_]]] +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 8db3611054..6dbd4309aa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -23,6 +23,8 @@ import scala.collection.Map import org.apache.spark._ import java.io._ import scala.Serializable +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.Utils private[spark] class ParallelCollectionPartition[T: ClassManifest]( var rddId: Long, diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 8e79a5c874..165cd412fc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext} +import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext} class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 98498d5ddf..d5304ab0ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import org.apache.spark.{RDD, SparkEnv, Partition, TaskContext} +import org.apache.spark.{SparkEnv, Partition, TaskContext} import org.apache.spark.broadcast.Broadcast diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala new file mode 100644 index 0000000000..e143ecd096 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -0,0 +1,942 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.Random + +import scala.collection.Map +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapred.TextOutputFormat + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +import org.apache.spark.Partitioner._ +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.partial.BoundedDouble +import org.apache.spark.partial.CountEvaluator +import org.apache.spark.partial.GroupedCountEvaluator +import org.apache.spark.partial.PartialResult +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{Utils, BoundedPriorityQueue} + +import org.apache.spark.SparkContext._ +import org.apache.spark._ + +/** + * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, + * partitioned collection of elements that can be operated on in parallel. This class contains the + * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, + * [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs of key-value + * pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] contains + * operations available only on RDDs of Doubles; and [[org.apache.spark.rdd.SequenceFileRDDFunctions]] + * contains operations available on RDDs that can be saved as SequenceFiles. These operations are + * automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit + * conversions when you `import org.apache.spark.SparkContext._`. + * + * Internally, each RDD is characterized by five main properties: + * + * - A list of partitions + * - A function for computing each split + * - A list of dependencies on other RDDs + * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) + * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for + * an HDFS file) + * + * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD + * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for + * reading data from a new storage system) by overriding these functions. Please refer to the + * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details + * on RDD internals. 
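The contract described in this scaladoc (subclasses supply getPartitions and compute) can be sketched with a toy RDD like the one below; the class and its behaviour are purely illustrative and not part of the patch.

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

class RangePartition(val index: Int) extends Partition

// A toy RDD that emits `perPartition` consecutive ints from each of its partitions.
class SmallRangeRDD(sc: SparkContext, numPartitions: Int, perPartition: Int)
  extends RDD[Int](sc, Nil) {

  override def getPartitions: Array[Partition] =
    (0 until numPartitions).map(i => new RangePartition(i): Partition).toArray

  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    val start = split.index * perPartition
    (start until start + perPartition).iterator
  }
}

// e.g. new SmallRangeRDD(sc, 4, 25).count() would return 100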
+ */ +abstract class RDD[T: ClassManifest]( + @transient private var sc: SparkContext, + @transient private var deps: Seq[Dependency[_]] + ) extends Serializable with Logging { + + /** Construct an RDD with just a one-to-one dependency on one parent */ + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) + + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= + + /** Implemented by subclasses to compute a given partition. */ + def compute(split: Partition, context: TaskContext): Iterator[T] + + /** + * Implemented by subclasses to return the set of partitions in this RDD. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getPartitions: Array[Partition] + + /** + * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getDependencies: Seq[Dependency[_]] = deps + + /** Optionally overridden by subclasses to specify placement preferences. */ + protected def getPreferredLocations(split: Partition): Seq[String] = Nil + + /** Optionally overridden by subclasses to specify how they are partitioned. */ + val partitioner: Option[Partitioner] = None + + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= + + /** The SparkContext that created this RDD. */ + def sparkContext: SparkContext = sc + + /** A unique ID for this RDD (within its SparkContext). */ + val id: Int = sc.newRddId() + + /** A friendly name for this RDD */ + var name: String = null + + /** Assign a name to this RDD */ + def setName(_name: String) = { + name = _name + this + } + + /** User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo.firstUserClass + + /** Reset generator*/ + def setGenerator(_generator: String) = { + generator = _generator + } + + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. + */ + def persist(newLevel: StorageLevel): RDD[T] = { + // TODO: Handle changes of StorageLevel + if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { + throw new UnsupportedOperationException( + "Cannot change storage level of an RDD after it was already assigned a level") + } + storageLevel = newLevel + // Register the RDD with the SparkContext + sc.persistentRdds(id) = this + this + } + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def cache(): RDD[T] = persist() + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + * @return This RDD. 
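A short caching sketch for the persist/cache/unpersist methods above, assuming an existing sc; the input path is a placeholder.

import org.apache.spark.storage.StorageLevel

val lines  = sc.textFile("hdfs:///tmp/input.txt")    // placeholder path
val cached = lines.cache()                           // MEMORY_ONLY
val onDisk = lines.map(_.length).persist(StorageLevel.MEMORY_AND_DISK)

println(cached.count())
println(cached.getStorageLevel)

cached.unpersist()    // removes the blocks; blocks until done (blocking = true)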
+ */ + def unpersist(blocking: Boolean = true): RDD[T] = { + logInfo("Removing RDD " + id + " from persistence list") + sc.env.blockManager.master.removeRdd(id, blocking) + sc.persistentRdds.remove(id) + storageLevel = StorageLevel.NONE + this + } + + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ + def getStorageLevel = storageLevel + + // Our dependencies and partitions will be gotten by calling subclass's methods below, and will + // be overwritten when we're checkpointed + private var dependencies_ : Seq[Dependency[_]] = null + @transient private var partitions_ : Array[Partition] = null + + /** An Option holding our checkpoint RDD, if we are checkpointed */ + private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + + /** + * Get the list of dependencies of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def dependencies: Seq[Dependency[_]] = { + checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { + if (dependencies_ == null) { + dependencies_ = getDependencies + } + dependencies_ + } + } + + /** + * Get the array of partitions of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def partitions: Array[Partition] = { + checkpointRDD.map(_.partitions).getOrElse { + if (partitions_ == null) { + partitions_ = getPartitions + } + partitions_ + } + } + + /** + * Get the preferred locations of a partition (as hostnames), taking into account whether the + * RDD is checkpointed. + */ + final def preferredLocations(split: Partition): Seq[String] = { + checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { + getPreferredLocations(split) + } + } + + /** + * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. + * This should ''not'' be called by users directly, but is available for implementors of custom + * subclasses of RDD. + */ + final def iterator(split: Partition, context: TaskContext): Iterator[T] = { + if (storageLevel != StorageLevel.NONE) { + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + } else { + computeOrReadCheckpoint(split, context) + } + } + + /** + * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. + */ + private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { + if (isCheckpointed) { + firstParent[T].iterator(split, context) + } else { + compute(split, context) + } + } + + // Transformations (return a new RDD) + + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ + def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) + + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ + def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = + new FlatMappedRDD(this, sc.clean(f)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ + def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numPartitions: Int): RDD[T] = + map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + + def distinct(): RDD[T] = distinct(partitions.size) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. 
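The core transformations above (map, flatMap, filter, distinct) in a minimal sketch, assuming an existing sc; the words are arbitrary sample data.

val words = sc.parallelize(Seq("spark", "rdd", "spark", "scala"))

val lengths = words.map(_.length)
val letters = words.flatMap(_.toSeq)       // one Char per element
val long    = words.filter(_.length > 3)
val unique  = words.distinct()             // deduplicates via reduceByKey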
+ */ + def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = { + if (shuffle) { + // include a shuffle step so that our upstream tasks are still distributed + new CoalescedRDD( + new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), + new HashPartitioner(numPartitions)), + numPartitions).keys + } else { + new CoalescedRDD(this, numPartitions) + } + } + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = + new SampledRDD(this, withReplacement, fraction, seed) + + def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { + var fraction = 0.0 + var total = 0 + val multiplier = 3.0 + val initialCount = this.count() + var maxSelected = 0 + + if (num < 0) { + throw new IllegalArgumentException("Negative number of elements requested") + } + + if (initialCount > Integer.MAX_VALUE - 1) { + maxSelected = Integer.MAX_VALUE - 1 + } else { + maxSelected = initialCount.toInt + } + + if (num > initialCount && !withReplacement) { + total = maxSelected + fraction = multiplier * (maxSelected + 1) / initialCount + } else { + fraction = multiplier * (num + 1) / initialCount + total = num + } + + val rand = new Random(seed) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for thei initial size + while (samples.length < total) { + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + } + + Utils.randomizeInPlace(samples, rand).take(total) + } + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ + def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ + def ++(other: RDD[T]): RDD[T] = this.union(other) + + /** + * Return an RDD created by coalescing all elements within each partition into an array. + */ + def glom(): RDD[Array[T]] = new GlommedRDD(this) + + /** + * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of + * elements (a, b) where a is in `this` and b is in `other`. + */ + def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + + /** + * Return an RDD of grouped items. + */ + def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = + groupBy[K](f, defaultPartitioner(this)) + + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ + def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = + groupBy(f, new HashPartitioner(numPartitions)) + + /** + * Return an RDD of grouped items. + */ + def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { + val cleanF = sc.clean(f) + this.map(t => (cleanF(t), t)).groupByKey(p) + } + + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: String): RDD[String] = new PipedRDD(this, command) + + /** + * Return an RDD created by piping elements to a forked external process. 
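The partitioning and grouping operators above can be exercised as in the following sketch (`sc` assumed):

    val nums     = sc.parallelize(1 to 100, 10)
    val narrowed = nums.coalesce(4)                     // CoalescedRDD, no shuffle step
    val spread   = nums.coalesce(20, shuffle = true)    // adds a shuffle to raise the partition count
    val sampled  = nums.sample(false, 0.1, 42)          // no replacement, 10% fraction, seed 42
    val byParity = nums.groupBy(n => n % 2)             // RDD[(Int, Seq[Int])]
    val combined = nums.union(sampled)                  // keeps duplicates; follow with distinct() to drop them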
+ */ + def pipe(command: String, env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) + + + /** + * Return an RDD created by piping elements to a forked external process. + * The print behavior can be customized by providing two functions. + * + * @param command command to run in forked process. + * @param env environment variables to set. + * @param printPipeContext Before piping elements, this function is called as an oppotunity + * to pipe context data. Print line function (like out.println) will be + * passed as printPipeContext's parameter. + * @param printRDDElement Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. + * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} + * @return the result RDD + */ + def pipe( + command: Seq[String], + env: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + new PipedRDD(this, command, env, + if (printPipeContext ne null) sc.clean(printPipeContext) else null, + if (printRDDElement ne null) sc.clean(printRDDElement) else null) + + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ + def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. + */ + def mapPartitionsWithIndex[U: ClassManifest]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. + */ + @deprecated("use mapPartitionsWithIndex", "0.7.0") + def mapPartitionsWithSplit[U: ClassManifest]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Maps f over this RDD, where f takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) + (f:(T, A) => U): RDD[U] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { + val a = constructA(index) + iter.map(t => f(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + } + + /** + * FlatMaps f over this RDD, where f takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. 
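A sketch of the partition-wise operators above (`sc` assumed):

    val nums = sc.parallelize(1 to 10, 2)
    // One result per partition: sum the elements of each partition locally.
    val partSums = nums.mapPartitions(iter => Iterator(iter.sum))
    // Tag every element with the index of the partition that holds it.
    val tagged = nums.mapPartitionsWithIndex((index, iter) => iter.map(x => (index, x)))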
+ */ + def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) + (f:(T, A) => Seq[U]): RDD[U] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { + val a = constructA(index) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + } + + /** + * Applies f to each element of this RDD, where f takes an additional parameter of type A. + * This additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def foreachWith[A: ClassManifest](constructA: Int => A) + (f:(T, A) => Unit) { + def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { + val a = constructA(index) + iter.map(t => {f(t, a); t}) + } + (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + } + + /** + * Filters this RDD with p, where p takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def filterWith[A: ClassManifest](constructA: Int => A) + (p:(T, A) => Boolean): RDD[T] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { + val a = constructA(index) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + } + + /** + * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, + * second element in each RDD, etc. Assumes that the two RDDs have the *same number of + * partitions* and the *same number of elements in each partition* (e.g. one was made through + * a map on the other). + */ + def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. + */ + def zipPartitions[B: ClassManifest, V: ClassManifest] + (rdd2: RDD[B]) + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + + def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], rdd3: RDD[C]) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + + def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + + + // Actions (launch a job to return a value to the user program) + + /** + * Applies a function f to all elements of this RDD. + */ + def foreach(f: T => Unit) { + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) + } + + /** + * Applies a function f to each partition of this RDD. + */ + def foreachPartition(f: Iterator[T] => Unit) { + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) + } + + /** + * Return an array that contains all of the elements in this RDD. 
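A minimal sketch of zip and zipPartitions; both RDDs are given the same number of partitions, as the scaladoc above requires (`sc` assumed).

    val a = sc.parallelize(1 to 4, 2)
    val b = sc.parallelize(Seq("a", "b", "c", "d"), 2)
    val pairs = a.zip(b)                      // RDD[(Int, String)], pairing element by element
    val sizes = a.zipPartitions(b) { (xs, ys) =>
      Iterator(xs.size + ys.size)             // one combined count per pair of partitions
    }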
+ */ + def collect(): Array[T] = { + val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) + Array.concat(results: _*) + } + + /** + * Return an array that contains all of the elements in this RDD. + */ + def toArray(): Array[T] = collect() + + /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: RDD[T]): RDD[T] = + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: RDD[T], numPartitions: Int): RDD[T] = + subtract(other, new HashPartitioner(numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: RDD[T], p: Partitioner): RDD[T] = { + if (partitioner == Some(p)) { + // Our partitioner knows how to handle T (which, since we have a partitioner, is + // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples + val p2 = new Partitioner() { + override def numPartitions = p.numPartitions + override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + } + // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies + // anyway, and when calling .keys, will not have a partitioner set, even though + // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be + // partitioned by the right/real keys (e.g. p). + this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys + } else { + this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys + } + } + + /** + * Reduces the elements of this RDD using the specified commutative and associative binary operator. + */ + def reduce(f: (T, T) => T): T = { + val cleanF = sc.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + var jobResult: Option[T] = None + val mergeResult = (index: Int, taskResult: Option[T]) => { + if (taskResult != None) { + jobResult = jobResult match { + case Some(value) => Some(f(value, taskResult.get)) + case None => taskResult + } + } + } + sc.runJob(this, reducePartition, mergeResult) + // Get the final result out of our Option, or throw an exception if the RDD was empty + jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) + } + + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using a + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * modify t1 and return it as its result value to avoid object allocation; however, it should not + * modify t2. 
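A short sketch of reduce and subtract from this section (`sc` assumed):

    val nums   = sc.parallelize(1 to 5)
    val total  = nums.reduce(_ + _)                    // 15; throws UnsupportedOperationException on an empty RDD
    val others = sc.parallelize(Seq(4, 5, 6))
    val onlyInNums = nums.subtract(others).collect()   // Array(1, 2, 3), in some partition-dependent order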
+ */ + def fold(zeroValue: T)(op: (T, T) => T): T = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + val cleanOp = sc.clean(op) + val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) + val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) + sc.runJob(this, foldPartition, mergeResult) + jobResult + } + + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using + * given combine functions and a neutral "zero value". This function can return a different result + * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U + * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are + * allowed to modify and return their first argument instead of creating a new U to avoid memory + * allocation. + */ + def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + val cleanSeqOp = sc.clean(seqOp) + val cleanCombOp = sc.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) + sc.runJob(this, aggregatePartition, mergeResult) + jobResult + } + + /** + * Return the number of elements in the RDD. + */ + def count(): Long = { + sc.runJob(this, (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }).sum + } + + /** + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. + */ + def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + } + val evaluator = new CountEvaluator(partitions.size, confidence) + sc.runApproximateJob(this, countElements, evaluator, timeout) + } + + /** + * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final + * combine step happens locally on the master, equivalent to running a single reduce task. + */ + def countByValue(): Map[T, Long] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValue() does not support arrays") + } + // TODO: This should perhaps be distributed by default. + def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { + val map = new OLMap[T] + while (iter.hasNext) { + val v = iter.next() + map.put(v, map.getLong(v) + 1L) + } + Iterator(map) + } + def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { + val iter = m2.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) + } + return m1 + } + val myResult = mapPartitions(countPartition).reduce(mergeMaps) + myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map + } + + /** + * (Experimental) Approximate version of countByValue(). 
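The difference between fold and aggregate above is easiest to see with a result type that differs from the element type; a sketch (`sc` assumed):

    val words  = sc.parallelize(Seq("a", "bb", "ccc"))
    val joined = words.fold("")(_ + _)                 // concatenation; ordering depends on partitioning
    // aggregate: seqOp folds each String into a running Int, combOp merges the per-partition Ints.
    val totalChars = words.aggregate(0)((acc, w) => acc + w.length, _ + _)   // 6
    val freq = words.countByValue()                    // Map("a" -> 1L, "bb" -> 1L, "ccc" -> 1L)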
+ */ + def countByValueApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[Map[T, BoundedDouble]] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValueApprox() does not support arrays") + } + val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => + val map = new OLMap[T] + while (iter.hasNext) { + val v = iter.next() + map.put(v, map.getLong(v) + 1L) + } + map + } + val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + sc.runApproximateJob(this, countPartition, evaluator, timeout) + } + + /** + * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so + * it will be slow if a lot of partitions are required. In that case, use collect() to get the + * whole RDD instead. + */ + def take(num: Int): Array[T] = { + if (num == 0) { + return new Array[T](0) + } + val buf = new ArrayBuffer[T] + var p = 0 + while (buf.size < num && p < partitions.size) { + val left = num - buf.size + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true) + buf ++= res(0) + if (buf.size == num) + return buf.toArray + p += 1 + } + return buf.toArray + } + + /** + * Return the first element in this RDD. + */ + def first(): T = take(1) match { + case Array(t) => t + case _ => throw new UnsupportedOperationException("empty collection") + } + + /** + * Returns the top K elements from this RDD as defined by + * the specified implicit Ordering[T]. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { + mapPartitions { items => + val queue = new BoundedPriorityQueue[T](num) + queue ++= items + Iterator.single(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray.sorted(ord.reverse) + } + + /** + * Returns the first K elements from this RDD as defined by + * the specified implicit Ordering[T] and maintains the + * ordering. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse) + + /** + * Save this RDD as a text file, using string representations of elements. + */ + def saveAsTextFile(path: String) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) + } + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) + } + + /** + * Save this RDD as a SequenceFile of serialized objects. + */ + def saveAsObjectFile(path: String) { + this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) + .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) + .saveAsSequenceFile(path) + } + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: T => K): RDD[(K, T)] = { + map(x => (f(x), x)) + } + + /** A private method for tests, to look at the contents of each partition */ + private[spark] def collectPartitions(): Array[Array[T]] = { + sc.runJob(this, (iter: Iterator[T]) => iter.toArray) + } + + /** + * Mark this RDD for checkpointing. 
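A sketch of the smaller actions above, take, top, takeOrdered and keyBy (`sc` assumed):

    val nums = sc.parallelize(Seq(7, 1, 9, 3))
    val firstTwo = nums.take(2)          // first two elements in partition order, e.g. Array(7, 1)
    val largest  = nums.top(2)           // two largest under the implicit Ordering: Array(9, 7)
    val smallest = nums.takeOrdered(2)   // two smallest, in ascending order: Array(1, 3)
    val keyed    = nums.keyBy(_ % 2)     // RDD[(Int, Int)] keyed by parity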
It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. + */ + def checkpoint() { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { + checkpointData = Some(new RDDCheckpointData(this)) + checkpointData.get.markForCheckpoint() + } + } + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed: Boolean = { + checkpointData.map(_.isCheckpointed).getOrElse(false) + } + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile: Option[String] = { + checkpointData.flatMap(_.getCheckpointFile) + } + + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.formatSparkCallSite + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ + def context = sc + + // Avoid handling doCheckpoint multiple times to prevent excessive recursion + private var doCheckpointCalled = false + + /** + * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler + * after a job using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. + */ + private[spark] def doCheckpoint() { + if (!doCheckpointCalled) { + doCheckpointCalled = true + if (checkpointData.isDefined) { + checkpointData.get.doCheckpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } + } + } + + /** + * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) + * created from the checkpoint file, and forget its old dependencies and partitions. + */ + private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { + clearDependencies() + partitions_ = null + deps = null // Forget the constructor argument for dependencies too + } + + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own cleaning + * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. + */ + protected def clearDependencies() { + dependencies_ = null + } + + /** A description of this RDD and its recursive dependencies for debugging. 
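A sketch of the checkpoint workflow from the user's side, combining checkpoint() and isCheckpointed above with toDebugString defined just below; the SparkContext `sc` and the HDFS path are assumptions for illustration.

    sc.setCheckpointDir("hdfs:///tmp/checkpoints")   // must be set before calling checkpoint()
    val data = sc.parallelize(1 to 100).map(_ * 2)
    data.checkpoint()                                // only marks the RDD; nothing is written yet
    data.count()                                     // after the first job, doCheckpoint() materializes it
    println(data.isCheckpointed)                     // true: the lineage now points at a CheckpointRDD
    println(data.toDebugString)                      // the debug string stops at the checkpointed parent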
*/ + def toDebugString: String = { + def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { + Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++ + rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + } + debugString(this).mkString("\n") + } + + override def toString: String = "%s%s[%d] at %s".format( + Option(name).map(_ + " ").getOrElse(""), + getClass.getSimpleName, + id, + origin) + + def toJavaRDD() : JavaRDD[T] = { + new JavaRDD(this)(elementClassManifest) + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala new file mode 100644 index 0000000000..6009a41570 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{Partition, SparkException, Logging} +import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} + +/** + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} + +/** + * This class contains all the information related to RDD checkpointing. Each instance of this class + * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, + * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations + * of the checkpointed RDD. + */ +private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) + extends Logging with Serializable { + + import CheckpointState._ + + // The checkpoint state of the associated RDD. + var cpState = Initialized + + // The file to which the associated RDD has been checkpointed to + @transient var cpFile: Option[String] = None + + // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. 
+ var cpRDD: Option[RDD[T]] = None + + // Mark the RDD for checkpointing + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } + } + + // Is the RDD already checkpointed + def isCheckpointed: Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } + } + + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile: Option[String] = { + RDDCheckpointData.synchronized { cpFile } + } + + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. + def doCheckpoint() { + // If it is marked for checkpointing AND checkpointing is not already in progress, + // then set it to be in progress, else return + RDDCheckpointData.synchronized { + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress + } else { + return + } + } + + // Create the output path for the checkpoint + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val fs = path.getFileSystem(new Configuration()) + if (!fs.mkdirs(path)) { + throw new SparkException("Failed to create checkpoint path " + path) + } + + // Save to file, and reload it as an RDD + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val newRDD = new CheckpointRDD[T](rdd.context, path.toString) + + // Change the dependencies and partitions of the RDD + RDDCheckpointData.synchronized { + cpFile = Some(path.toString) + cpRDD = Some(newRDD) + rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions + cpState = Checkpointed + RDDCheckpointData.clearTaskCaches() + logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) + } + } + + // Get preferred location of a split after checkpointing + def getPreferredLocations(split: Partition): Seq[String] = { + RDDCheckpointData.synchronized { + cpRDD.get.preferredLocations(split) + } + } + + def getPartitions: Array[Partition] = { + RDDCheckpointData.synchronized { + cpRDD.get.partitions + } + } + + def checkpointRDD: Option[RDD[T]] = { + RDDCheckpointData.synchronized { + cpRDD + } + } +} + +private[spark] object RDDCheckpointData { + def clearTaskCaches() { + ShuffleMapTask.clearCache() + ResultTask.clearCache() + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index 1e8d89e912..2c5253ae30 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -22,7 +22,7 @@ import java.util.Random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import org.apache.spark.{RDD, Partition, TaskContext} +import org.apache.spark.{Partition, TaskContext} private[spark] class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala new file mode 100644 index 0000000000..5fe4676029 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.SequenceFileOutputFormat +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.Writable + +import org.apache.spark.SparkContext._ +import org.apache.spark.Logging + +/** + * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, + * through an implicit conversion. Note that this can't be part of PairRDDFunctions because + * we need more implicit parameters to convert our keys and values to Writable. + * + * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. + */ +class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( + self: RDD[(K, V)]) + extends Logging + with Serializable { + + private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { + val c = { + if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { + classManifest[T].erasure + } else { + // We get the type of the Writable class by looking at the apply method which converts + // from T to Writable. Since we have two apply methods we filter out the one which + // is not of the form "java.lang.Object apply(java.lang.Object)" + implicitly[T => Writable].getClass.getDeclaredMethods().filter( + m => m.getReturnType().toString != "class java.lang.Object" && + m.getName() == "apply")(0).getReturnType + + } + // TODO: use something like WritableConverter to avoid reflection + } + c.asInstanceOf[Class[_ <: Writable]] + } + + /** + * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key + * and value types. If the key or value are Writable, then we use their classes directly; + * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc, + * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported + * file system. 
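A sketch of how saveAsSequenceFile, defined just below, is reached through the implicit conversion described above; `sc`, the output path, and the usual Int/String-to-Writable implicits from SparkContext are assumptions for illustration.

    import org.apache.spark.SparkContext._             // brings SequenceFileRDDFunctions into scope
    val pairs = sc.parallelize(Seq(1 -> "one", 2 -> "two"))
    pairs.saveAsSequenceFile("hdfs:///tmp/pairs")      // Int and String are written as IntWritable and Text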
+ */ + def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { + def anyToWritable[U <% Writable](u: U): Writable = u + + val keyClass = getWritableClass[K] + val valueClass = getWritableClass[V] + val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass) + val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass) + + logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) + val format = classOf[SequenceFileOutputFormat[Writable, Writable]] + val jobConf = new JobConf(self.context.hadoopConfiguration) + if (!convertKey && !convertValue) { + self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) + } else if (!convertKey && convertValue) { + self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) + } else if (convertKey && !convertValue) { + self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) + } else if (convertKey && convertValue) { + self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index f0e9ab8b80..9537152335 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{Dependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} +import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext} private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 7369dfaa74..8c1a29dfff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -20,7 +20,6 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.RDD import org.apache.spark.Partitioner import org.apache.spark.Dependency import org.apache.spark.TaskContext diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index fd02476b62..ae8a9f36a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.rdd import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext} +import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 5ae1db3e67..31e6fd519d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -17,7 +17,7 
@@ package org.apache.spark.rdd -import org.apache.spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} private[spark] class ZippedPartitionsPartition( diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala index 3bd00d291b..567b67dfee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5ac700bbf4..92add5b073 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.scheduler.cluster.TaskInfo diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 5b07933eed..0d99670648 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -23,6 +23,7 @@ import org.apache.spark.scheduler.cluster.TaskInfo import scala.collection.mutable.Map import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.executor.TaskMetrics /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 98ef4d1e63..c8b78bf00a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -28,6 +28,7 @@ import scala.collection.mutable.{Map, HashMap, ListBuffer} import scala.io.Source import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.TaskInfo diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 2f157ccdd2..2b007cbe82 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,11 +17,14 @@ package org.apache.spark.scheduler -import org.apache.spark._ import java.io._ -import util.{MetadataCleaner, TimeStampedHashMap} import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.RDDCheckpointData +import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} + private[spark] object ResultTask { // A simple map between the stage id to the serialized byte array of a task. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ca716b44e8..764775fede 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -26,6 +26,8 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage._ import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner} +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.RDDCheckpointData private[spark] object ShuffleMapTask { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3504424fa9..c3cf4b8907 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -19,8 +19,8 @@ package org.apache.spark.scheduler import java.util.Properties import org.apache.spark.scheduler.cluster.TaskInfo -import org.apache.spark.util.Distribution -import org.apache.spark.{Logging, SparkContext, TaskEndReason, Utils} +import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.{Logging, SparkContext, TaskEndReason} import org.apache.spark.executor.TaskMetrics sealed trait SparkListenerEvents diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 87b1fe4e0c..aa293dc6b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import java.net.URI - import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 776675d28c..5c7e5bb977 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -21,8 +21,9 @@ import java.io._ import scala.collection.mutable.Map import org.apache.spark.executor.TaskMetrics -import org.apache.spark.{Utils, SparkEnv} +import org.apache.spark.{SparkEnv} import java.nio.ByteBuffer +import org.apache.spark.util.Utils // Task result. Also contains updates to accumulator variables. 
// TODO: Use of distributed cache to return result is a hack to get around diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index a33307b83a..1b31c8c57e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import org.apache.spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState, Utils} +import org.apache.spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState} import org.apache.spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler._ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index bde2f73df4..d57eb3276f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{SparkContext, Utils} +import org.apache.spark.{SparkContext} /** * A backend interface for cluster scheduling systems that allows plugging in different ones under diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ac6dc7d879..d003bf1bba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,10 +17,11 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{Utils, Logging, SparkContext} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.client.{Client, ClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} import scala.collection.mutable.HashMap +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1cc5daf673..9c36d221f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -20,8 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState -import org.apache.spark.Utils -import org.apache.spark.util.SerializableBuffer +import org.apache.spark.util.{Utils, SerializableBuffer} private[spark] sealed trait StandaloneClusterMessage extends Serializable diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 3677a827e0..b4ea0be415 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -28,8 +28,9 @@ import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClient import akka.util.Duration import akka.util.duration._ -import org.apache.spark.{Utils, SparkException, Logging, TaskState} +import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.util.Utils /** * A standalone scheduler backend, which waits for standalone executors to connect to it through diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala index 7ce14be7fb..9685fb1a67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.Utils +import org.apache.spark.util.Utils /** * Information about a running task attempt inside a TaskSet. diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index f0ebe66d82..e8fa5e2f17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -34,6 +34,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster._ import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode import akka.actor._ +import org.apache.spark.util.Utils /** * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. 
Optionally diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index f6a2feab28..3dbe61d706 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -23,7 +23,7 @@ import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} -import org.apache.spark.{SparkException, Utils, Logging, SparkContext} +import org.apache.spark.{SparkException, Logging, SparkContext} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.JavaConversions._ import java.io.File diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala index e002af1742..541f86e338 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -23,7 +23,7 @@ import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} -import org.apache.spark.{SparkException, Utils, Logging, SparkContext} +import org.apache.spark.{SparkException, Logging, SparkContext} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.JavaConversions._ import java.io.File @@ -31,6 +31,7 @@ import org.apache.spark.scheduler.cluster._ import java.util.{ArrayList => JArrayList, List => JList} import java.util.Collections import org.apache.spark.TaskState +import org.apache.spark.util.Utils /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala new file mode 100644 index 0000000000..4de81617b1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io._ +import java.nio.ByteBuffer + +import org.apache.spark.util.ByteBufferInputStream + +private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { + val objOut = new ObjectOutputStream(out) + def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } + def flush() { objOut.flush() } + def close() { objOut.close() } +} + +private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) +extends DeserializationStream { + val objIn = new ObjectInputStream(in) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + + def readObject[T](): T = objIn.readObject().asInstanceOf[T] + def close() { objIn.close() } +} + +private[spark] class JavaSerializerInstance extends SerializerInstance { + def serialize[T](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val out = serializeStream(bos) + out.writeObject(t) + out.close() + ByteBuffer.wrap(bos.toByteArray) + } + + def deserialize[T](bytes: ByteBuffer): T = { + val bis = new ByteBufferInputStream(bytes) + val in = deserializeStream(bis) + in.readObject().asInstanceOf[T] + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val bis = new ByteBufferInputStream(bytes) + val in = deserializeStream(bis, loader) + in.readObject().asInstanceOf[T] + } + + def serializeStream(s: OutputStream): SerializationStream = { + new JavaSerializationStream(s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader) + } + + def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { + new JavaDeserializationStream(s, loader) + } +} + +/** + * A Spark serializer that uses Java's built-in serialization. + */ +class JavaSerializer extends Serializer { + def newInstance(): SerializerInstance = new JavaSerializerInstance +} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala new file mode 100644 index 0000000000..24ef204aa1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
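A quick sketch of round-tripping a value through the JavaSerializer defined above; the value being serialized is arbitrary.

    import org.apache.spark.serializer.JavaSerializer
    val instance = new JavaSerializer().newInstance()
    val bytes    = instance.serialize(List(1, 2, 3))        // ByteBuffer produced via java.io serialization
    val restored = instance.deserialize[List[Int]](bytes)   // List(1, 2, 3)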
+ */ + +package org.apache.spark.serializer + +import java.nio.ByteBuffer +import java.io.{EOFException, InputStream, OutputStream} + +import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.{KryoException, Kryo} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import com.twitter.chill.ScalaKryoInstantiator + +import org.apache.spark.{SerializableWritable, Logging} +import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel} + +import org.apache.spark.broadcast.HttpBroadcast + +/** + * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. + */ +class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { + private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + + def newKryoOutput() = new KryoOutput(bufferSize) + + def newKryoInput() = new KryoInput(bufferSize) + + def newKryo(): Kryo = { + val instantiator = new ScalaKryoInstantiator + val kryo = instantiator.newKryo() + val classLoader = Thread.currentThread.getContextClassLoader + + // Register some commonly used classes + val toRegister: Seq[AnyRef] = Seq( + ByteBuffer.allocate(1), + StorageLevel.MEMORY_ONLY, + PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), + GotBlock("1", ByteBuffer.allocate(1)), + GetBlock("1") + ) + + for (obj <- toRegister) kryo.register(obj.getClass) + + // Allow sending SerializableWritable + kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) + + // Allow the user to register their own classes by setting spark.kryo.registrator + try { + Option(System.getProperty("spark.kryo.registrator")).foreach { regCls => + logDebug("Running user registrator: " + regCls) + val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] + reg.registerClasses(kryo) + } + } catch { + case _: Exception => println("Failed to register spark.kryo.registrator") + } + + kryo.setClassLoader(classLoader) + + // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops + kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) + + kryo + } + + def newInstance(): SerializerInstance = { + new KryoSerializerInstance(this) + } +} + +private[spark] +class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { + val output = new KryoOutput(outStream) + + def writeObject[T](t: T): SerializationStream = { + kryo.writeClassAndObject(output, t) + this + } + + def flush() { output.flush() } + def close() { output.close() } +} + +private[spark] +class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { + val input = new KryoInput(inStream) + + def readObject[T](): T = { + try { + kryo.readClassAndObject(input).asInstanceOf[T] + } catch { + // DeserializationStream uses the EOF exception to indicate stopping condition. + case _: KryoException => throw new EOFException + } + } + + def close() { + // Kryo's Input automatically closes the input stream it is using. 
+ input.close() + } +} + +private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { + val kryo = ks.newKryo() + val output = ks.newKryoOutput() + val input = ks.newKryoInput() + + def serialize[T](t: T): ByteBuffer = { + output.clear() + kryo.writeClassAndObject(output, t) + ByteBuffer.wrap(output.toBytes) + } + + def deserialize[T](bytes: ByteBuffer): T = { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val oldClassLoader = kryo.getClassLoader + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + val obj = kryo.readClassAndObject(input).asInstanceOf[T] + kryo.setClassLoader(oldClassLoader) + obj + } + + def serializeStream(s: OutputStream): SerializationStream = { + new KryoSerializationStream(kryo, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new KryoDeserializationStream(kryo, s) + } +} + +/** + * Interface implemented by clients to register their classes with Kryo when using Kryo + * serialization. + */ +trait KryoRegistrator { + def registerClasses(kryo: Kryo) +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index c91f0fc1ad..3aeda3879d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -27,12 +27,12 @@ import scala.collection.mutable.Queue import io.netty.buffer.ByteBuf import org.apache.spark.Logging -import org.apache.spark.Utils import org.apache.spark.SparkException import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId import org.apache.spark.network.netty.ShuffleCopier import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3299ac98d5..60fdc5f2ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,11 +29,11 @@ import akka.util.duration._ import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import org.apache.spark.{Logging, SparkEnv, SparkException, Utils} +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer -import org.apache.spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.util._ import sun.nio.ch.DirectBuffer diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index a22a80decc..74207f59af 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.Utils +import org.apache.spark.util.Utils /** * This class represent an unique identifier for a BlockManager. 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index baa4a1da50..c7b23ab094 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -28,8 +28,9 @@ import akka.pattern.ask import akka.util.Duration import akka.util.duration._ -import org.apache.spark.{Logging, Utils, SparkException} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index f4856020e5..678c38203c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -19,8 +19,9 @@ package org.apache.spark.storage import java.nio.ByteBuffer -import org.apache.spark.{Logging, Utils} +import org.apache.spark.{Logging} import org.apache.spark.network._ +import org.apache.spark.util.Utils /** * A network interface for BlockManager. Each slave should have one diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index fd945e065c..fc25ef0fae 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -28,12 +28,12 @@ import scala.collection.mutable.ArrayBuffer import it.unimi.dsi.fastutil.io.FastBufferedOutputStream -import org.apache.spark.Utils import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.serializer.{Serializer, SerializationStream} import org.apache.spark.Logging import org.apache.spark.network.netty.ShuffleSender import org.apache.spark.network.netty.PathResolver +import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 828dc0f22d..3b3b2342fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -19,9 +19,9 @@ package org.apache.spark.storage import java.util.LinkedHashMap import java.util.concurrent.ArrayBlockingQueue -import org.apache.spark.{SizeEstimator, Utils} import java.nio.ByteBuffer import collection.mutable.ArrayBuffer +import org.apache.spark.util.{SizeEstimator, Utils} /** * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 0bba1dac54..2bb7715696 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -17,8 +17,9 @@ package org.apache.spark.storage -import org.apache.spark.{Utils, SparkContext} +import org.apache.spark.{SparkContext} import BlockManagerMasterActor.BlockStatus +import org.apache.spark.util.Utils private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 1d5afe9b08..f2ae8dd97d 
100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -19,9 +19,9 @@ package org.apache.spark.storage import akka.actor._ -import org.apache.spark.KryoSerializer import java.util.concurrent.ArrayBlockingQueue import util.Random +import org.apache.spark.serializer.KryoSerializer /** * This class tests the BlockManager and MemoryStore for thread safety and diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4688effe0a..ad456ea565 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -21,12 +21,13 @@ import javax.servlet.http.HttpServletRequest import org.eclipse.jetty.server.{Handler, Server} -import org.apache.spark.{Logging, SparkContext, SparkEnv, Utils} +import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.ui.env.EnvironmentUI import org.apache.spark.ui.exec.ExecutorsUI import org.apache.spark.ui.storage.BlockManagerUI import org.apache.spark.ui.jobs.JobProgressUI import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.util.Utils /** Top level user interface for Spark */ private[spark] class SparkUI(sc: SparkContext) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index efe6b474e0..6e56c22d04 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -7,13 +7,14 @@ import scala.xml.Node import org.eclipse.jetty.server.Handler -import org.apache.spark.{ExceptionFailure, Logging, Utils, SparkContext} +import org.apache.spark.{ExceptionFailure, Logging, SparkContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.TaskInfo import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.Page.Executors import org.apache.spark.ui.UIUtils +import org.apache.spark.util.Utils private[spark] class ExecutorsUI(val sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index ae02226300..86e0af0399 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -3,7 +3,7 @@ package org.apache.spark.ui.jobs import scala.Seq import scala.collection.mutable.{ListBuffer, HashMap, HashSet} -import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils} +import org.apache.spark.{ExceptionFailure, SparkContext, Success} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.TaskInfo import org.apache.spark.executor.TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index 1bb7638bd9..6aecef5120 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -29,11 +29,12 @@ import scala.Seq import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.{ExceptionFailure, 
SparkContext, Success, Utils} +import org.apache.spark.{ExceptionFailure, SparkContext, Success} import org.apache.spark.scheduler._ import collection.mutable import org.apache.spark.scheduler.cluster.SchedulingMode import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode +import org.apache.spark.util.Utils /** Web UI showing progress status of all jobs in the given SparkContext. */ private[spark] class JobProgressUI(val sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2fe85bc0cf..a9969ab1c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -25,8 +25,8 @@ import scala.xml.Node import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.Page._ -import org.apache.spark.util.Distribution -import org.apache.spark.{ExceptionFailure, Utils} +import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.{ExceptionFailure} import org.apache.spark.scheduler.cluster.TaskInfo import org.apache.spark.executor.TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index beb0574548..71e58a977e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -5,9 +5,9 @@ import java.util.Date import scala.xml.Node import scala.collection.mutable.HashSet -import org.apache.spark.Utils import org.apache.spark.scheduler.cluster.{SchedulingMode, TaskInfo} import org.apache.spark.scheduler.Stage +import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala index 1eb4a7a85e..c3ec907370 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala @@ -22,9 +22,9 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.storage.{RDDInfo, StorageUtils} -import org.apache.spark.Utils import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.Page._ +import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ private[spark] class IndexPage(parent: BlockManagerUI) { diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 37baf17f7a..43c1257677 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -21,11 +21,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.Utils import org.apache.spark.storage.{StorageStatus, StorageUtils} import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.Page._ +import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala new file mode 100644 index 0000000000..7108595e3e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -0,0 +1,232 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.reflect.Field + +import scala.collection.mutable.Map +import scala.collection.mutable.Set + +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.objectweb.asm.Opcodes._ +import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} +import org.apache.spark.Logging + +private[spark] object ClosureCleaner extends Logging { + // Get an ASM class reader for a given class from the JAR that loaded it + private def getClassReader(cls: Class[_]): ClassReader = { + // Copy data over, before delegating to ClassReader - else we can run out of open file handles. + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + // todo: Fixme - continuing with earlier behavior ... + if (resourceStream == null) return new ClassReader(resourceStream) + + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + } + + // Check whether a class represents a Scala closure + private def isClosure(cls: Class[_]): Boolean = { + cls.getName.contains("$anonfun$") + } + + // Get a list of the classes of the outer objects of a given closure object, obj; + // the outer objects are defined as any closures that obj is nested within, plus + // possibly the class that the outermost closure is in, if any. We stop searching + // for outer objects beyond that because cloning the user's object is probably + // not a good idea (whereas we can clone closure objects just fine since we + // understand how all their fields are used). + private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { + f.setAccessible(true) + if (isClosure(f.getType)) { + return f.getType :: getOuterClasses(f.get(obj)) + } else { + return f.getType :: Nil // Stop at the first $outer that is not a closure + } + } + return Nil + } + + // Get a list of the outer objects for a given closure object. 
+ private def getOuterObjects(obj: AnyRef): List[AnyRef] = { + for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { + f.setAccessible(true) + if (isClosure(f.getType)) { + return f.get(obj) :: getOuterObjects(f.get(obj)) + } else { + return f.get(obj) :: Nil // Stop at the first $outer that is not a closure + } + } + return Nil + } + + private def getInnerClasses(obj: AnyRef): List[Class[_]] = { + val seen = Set[Class[_]](obj.getClass) + var stack = List[Class[_]](obj.getClass) + while (!stack.isEmpty) { + val cr = getClassReader(stack.head) + stack = stack.tail + val set = Set[Class[_]]() + cr.accept(new InnerClosureFinder(set), 0) + for (cls <- set -- seen) { + seen += cls + stack = cls :: stack + } + } + return (seen - obj.getClass).toList + } + + private def createNullValue(cls: Class[_]): AnyRef = { + if (cls.isPrimitive) { + new java.lang.Byte(0: Byte) // Should be convertible to any primitive type + } else { + null + } + } + + def clean(func: AnyRef) { + // TODO: cache outerClasses / innerClasses / accessedFields + val outerClasses = getOuterClasses(func) + val innerClasses = getInnerClasses(func) + val outerObjects = getOuterObjects(func) + + val accessedFields = Map[Class[_], Set[String]]() + for (cls <- outerClasses) + accessedFields(cls) = Set[String]() + for (cls <- func.getClass :: innerClasses) + getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) + //logInfo("accessedFields: " + accessedFields) + + val inInterpreter = { + try { + val interpClass = Class.forName("spark.repl.Main") + interpClass.getMethod("interp").invoke(null) != null + } catch { + case _: ClassNotFoundException => true + } + } + + var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse + var outer: AnyRef = null + if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + outer = outerPairs.head._2 + outerPairs = outerPairs.tail + } + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. 
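The $outer chain that getOuterClasses and getOuterObjects walk is a Scala compiler artifact: an anonymous function that touches a field of its enclosing instance keeps a synthetic $outer reference to that instance. A small stand-alone illustration, assuming the Scala 2.9/2.10 compilers this code targets (Holder and OuterChainDemo are invented names):

    class Holder {
      val factor = 3
      // Referencing `factor` forces the generated anonfun class to hold a
      // synthetic $outer field pointing back at this Holder instance.
      def makeAdder: Int => Int = (x: Int) => x * factor
    }

    object OuterChainDemo extends App {
      val f = new Holder().makeAdder
      val outerField = f.getClass.getDeclaredFields.find(_.getName == "$outer")
      // Typically prints Some(class Holder): the same field ClosureCleaner
      // follows when deciding what to clone and what to null out.
      println(outerField.map(_.getType))
    }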
+ for ((cls, obj) <- outerPairs) { + outer = instantiateClass(cls, outer, inInterpreter) + for (fieldName <- accessedFields(cls)) { + val field = cls.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); + field.set(outer, value) + } + } + + if (outer != null) { + //logInfo("2: Setting $outer on " + func.getClass + " to " + outer); + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + field.set(func, outer) + } + } + + private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { + //logInfo("Creating a " + cls + " with outer = " + outer) + if (!inInterpreter) { + // This is a bona fide closure class, whose constructor has no effects + // other than to set its fields, so use its constructor + val cons = cls.getConstructors()(0) + val params = cons.getParameterTypes.map(createNullValue).toArray + if (outer != null) + params(0) = outer // First param is always outer object + return cons.newInstance(params: _*).asInstanceOf[AnyRef] + } else { + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory() + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef] + if (outer != null) { + //logInfo("3: Setting $outer on " + cls + " to " + outer); + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, outer) + } + return obj + } + } +} + +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + return new MethodVisitor(ASM4) { + override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { + if (op == GETFIELD) { + for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { + output(cl) += name + } + } + } + + override def visitMethodInsn(op: Int, owner: String, name: String, + desc: String) { + // Check for calls a getter method for a variable in an interpreter wrapper object. + // This means that the corresponding field will be accessed, so we should save it. + if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { + for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { + output(cl) += name + } + } + } + } + } +} + +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { + var myName: String = null + + override def visit(version: Int, access: Int, name: String, sig: String, + superName: String, interfaces: Array[String]) { + myName = name + } + + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + return new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, + desc: String) { + val argTypes = Type.getArgumentTypes(desc) + if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 + && argTypes(0).toString.startsWith("L") // is it an object? 
+ && argTypes(0).getInternalName == myName) + output += Class.forName( + owner.replace('/', '.'), + false, + Thread.currentThread.getContextClassLoader) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MemoryParam.scala b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala index 0ee6707826..4869c9897a 100644 --- a/core/src/main/scala/org/apache/spark/util/MemoryParam.scala +++ b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala @@ -17,8 +17,6 @@ package org.apache.spark.util -import org.apache.spark.Utils - /** * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala new file mode 100644 index 0000000000..a25b37a2a9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.reflect.Field +import java.lang.reflect.Modifier +import java.lang.reflect.{Array => JArray} +import java.util.IdentityHashMap +import java.util.concurrent.ConcurrentHashMap +import java.util.Random + +import javax.management.MBeanServer +import java.lang.management.ManagementFactory + +import scala.collection.mutable.ArrayBuffer + +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import org.apache.spark.Logging + +/** + * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in + * memory-aware caches. + * + * Based on the following JavaWorld article: + * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html + */ +private[spark] object SizeEstimator extends Logging { + + // Sizes of primitive types + private val BYTE_SIZE = 1 + private val BOOLEAN_SIZE = 1 + private val CHAR_SIZE = 2 + private val SHORT_SIZE = 2 + private val INT_SIZE = 4 + private val LONG_SIZE = 8 + private val FLOAT_SIZE = 4 + private val DOUBLE_SIZE = 8 + + // Alignment boundary for objects + // TODO: Is this arch dependent ? 
+ private val ALIGN_SIZE = 8 + + // A cache of ClassInfo objects for each class + private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] + + // Object and pointer sizes are arch dependent + private var is64bit = false + + // Size of an object reference + // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops + private var isCompressedOops = false + private var pointerSize = 4 + + // Minimum size of a java.lang.Object + private var objectSize = 8 + + initialize() + + // Sets object size, pointer size based on architecture and CompressedOops settings + // from the JVM. + private def initialize() { + is64bit = System.getProperty("os.arch").contains("64") + isCompressedOops = getIsCompressedOops + + objectSize = if (!is64bit) 8 else { + if(!isCompressedOops) { + 16 + } else { + 12 + } + } + pointerSize = if (is64bit && !isCompressedOops) 8 else 4 + classInfos.clear() + classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) + } + + private def getIsCompressedOops : Boolean = { + if (System.getProperty("spark.test.useCompressedOops") != null) { + return System.getProperty("spark.test.useCompressedOops").toBoolean + } + + try { + val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" + val server = ManagementFactory.getPlatformMBeanServer() + + // NOTE: This should throw an exception in non-Sun JVMs + val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") + val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", + Class.forName("java.lang.String")) + + val bean = ManagementFactory.newPlatformMXBeanProxy(server, + hotSpotMBeanName, hotSpotMBeanClass) + // TODO: We could use reflection on the VMOption returned ? + return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") + } catch { + case e: Exception => { + // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB + val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) + val guessInWords = if (guess) "yes" else "not" + logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) + return guess + } + } + } + + /** + * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an + * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects + * to visit. + */ + private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) { + val stack = new ArrayBuffer[AnyRef] + var size = 0L + + def enqueue(obj: AnyRef) { + if (obj != null && !visited.containsKey(obj)) { + visited.put(obj, null) + stack += obj + } + } + + def isFinished(): Boolean = stack.isEmpty + + def dequeue(): AnyRef = { + val elem = stack.last + stack.trimEnd(1) + return elem + } + } + + /** + * Cached information about each class. We remember two things: the "shell size" of the class + * (size of all non-static fields plus the java.lang.Object size), and any fields that are + * pointers to objects. 
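A quick back-of-the-envelope check of the constants above, assuming a 64-bit JVM with compressed oops (objectSize = 12, pointerSize = 4) and a hypothetical class holding one object reference and one Long field:

    object ShellSizeSketch extends App {
      val raw = 12 + 4 + 8               // header + reference + long = 24 bytes
      val aligned = ((raw + 7) / 8) * 8  // round up to the 8-byte ALIGN_SIZE boundary
      println(aligned)                   // 24 -- already aligned, so no padding is added
    }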
+ */ + private class ClassInfo( + val shellSize: Long, + val pointerFields: List[Field]) {} + + def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) + + private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { + val state = new SearchState(visited) + state.enqueue(obj) + while (!state.isFinished) { + visitSingleObject(state.dequeue(), state) + } + return state.size + } + + private def visitSingleObject(obj: AnyRef, state: SearchState) { + val cls = obj.getClass + if (cls.isArray) { + visitArray(obj, cls, state) + } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) { + // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses + // the size estimator since it references the whole REPL. Do nothing in this case. In + // general all ClassLoaders and Classes will be shared between objects anyway. + } else { + val classInfo = getClassInfo(cls) + state.size += classInfo.shellSize + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } + } + } + + // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. + private val ARRAY_SIZE_FOR_SAMPLING = 200 + private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING + + private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { + val length = JArray.getLength(array) + val elementClass = cls.getComponentType + + // Arrays have object header and length field which is an integer + var arrSize: Long = alignSize(objectSize + INT_SIZE) + + if (elementClass.isPrimitive) { + arrSize += alignSize(length * primitiveSize(elementClass)) + state.size += arrSize + } else { + arrSize += alignSize(length * pointerSize) + state.size += arrSize + + if (length <= ARRAY_SIZE_FOR_SAMPLING) { + for (i <- 0 until length) { + state.enqueue(JArray.get(array, i)) + } + } else { + // Estimate the size of a large array by sampling elements without replacement. + var size = 0.0 + val rand = new Random(42) + val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE) + for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var index = 0 + do { + index = rand.nextInt(length) + } while (drawn.contains(index)) + drawn.add(index) + val elem = JArray.get(array, index) + size += SizeEstimator.estimate(elem, state.visited) + } + state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong + } + } + } + + private def primitiveSize(cls: Class[_]): Long = { + if (cls == classOf[Byte]) + BYTE_SIZE + else if (cls == classOf[Boolean]) + BOOLEAN_SIZE + else if (cls == classOf[Char]) + CHAR_SIZE + else if (cls == classOf[Short]) + SHORT_SIZE + else if (cls == classOf[Int]) + INT_SIZE + else if (cls == classOf[Long]) + LONG_SIZE + else if (cls == classOf[Float]) + FLOAT_SIZE + else if (cls == classOf[Double]) + DOUBLE_SIZE + else throw new IllegalArgumentException( + "Non-primitive class " + cls + " passed to primitiveSize()") + } + + /** + * Get or compute the ClassInfo for a given class. 
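Putting the array path above together: a primitive Array[Long] of length 1000 carries a header plus length field that aligns to 16 bytes and 8000 bytes of data, so the walk should report roughly 8016 bytes on a compressed-oops JVM. A sketch of that check, which only compiles inside the Spark source tree because SizeEstimator is private[spark]:

    package org.apache.spark.util

    object ArrayEstimateSketch extends App {
      // Expected: about 8016 bytes with compressed oops (8024 without them).
      println(SizeEstimator.estimate(new Array[Long](1000)))
    }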
+ */ + private def getClassInfo(cls: Class[_]): ClassInfo = { + // Check whether we've already cached a ClassInfo for this class + val info = classInfos.get(cls) + if (info != null) { + return info + } + + val parent = getClassInfo(cls.getSuperclass) + var shellSize = parent.shellSize + var pointerFields = parent.pointerFields + + for (field <- cls.getDeclaredFields) { + if (!Modifier.isStatic(field.getModifiers)) { + val fieldClass = field.getType + if (fieldClass.isPrimitive) { + shellSize += primitiveSize(fieldClass) + } else { + field.setAccessible(true) // Enable future get()'s on this field + shellSize += pointerSize + pointerFields = field :: pointerFields + } + } + } + + shellSize = alignSize(shellSize) + + // Create and cache a new ClassInfo + val newInfo = new ClassInfo(shellSize, pointerFields) + classInfos.put(cls, newInfo) + return newInfo + } + + private def alignSize(size: Long): Long = { + val rem = size % ALIGN_SIZE + return if (rem == 0) size else (size + ALIGN_SIZE - rem) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala new file mode 100644 index 0000000000..bb47fc0a2c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -0,0 +1,781 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io._ +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} +import java.util.{Locale, Random, UUID} +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.regex.Pattern + +import scala.collection.Map +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ +import scala.io.Source + +import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder + +import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} + +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.deploy.SparkHadoopUtil +import java.nio.ByteBuffer +import org.apache.spark.{SparkEnv, SparkException, Logging} + + +/** + * Various utility methods used by Spark. 
+ */ +private[spark] object Utils extends Logging { + + /** Serialize an object using Java serialization */ + def serialize[T](o: T): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(bos) + oos.writeObject(o) + oos.close() + return bos.toByteArray + } + + /** Deserialize an object using Java serialization */ + def deserialize[T](bytes: Array[Byte]): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) + return ois.readObject.asInstanceOf[T] + } + + /** Deserialize an object using Java serialization and the given ClassLoader */ + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + return ois.readObject.asInstanceOf[T] + } + + /** Serialize via nested stream using specific serializer */ + def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { + val osWrapper = ser.serializeStream(new OutputStream { + def write(b: Int) = os.write(b) + + override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + }) + try { + f(osWrapper) + } finally { + osWrapper.close() + } + } + + /** Deserialize via nested stream using specific serializer */ + def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { + val isWrapper = ser.deserializeStream(new InputStream { + def read(): Int = is.read() + + override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) + }) + try { + f(isWrapper) + } finally { + isWrapper.close() + } + } + + /** + * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. + */ + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + def isAlpha(c: Char): Boolean = { + (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') + } + + /** Split a string into words at non-alphabetic characters */ + def splitWords(s: String): Seq[String] = { + val buf = new ArrayBuffer[String] + var i = 0 + while (i < s.length) { + var j = i + while (j < s.length && isAlpha(s.charAt(j))) { + j += 1 + } + if (j > i) { + buf += s.substring(i, j) + } + i = j + while (i < s.length && !isAlpha(s.charAt(i))) { + i += 1 + } + } + return buf + } + + private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. 
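Utils is private[spark], so the helpers above are only reachable from Spark's own packages; from there the Java-serialization round trip and splitWords behave as sketched below (the object name is invented):

    package org.apache.spark.util

    object UtilsSerializationSketch extends App {
      // Java-serialization round trip through the helpers above.
      val bytes = Utils.serialize(List(1, 2, 3))
      println(Utils.deserialize[List[Int]](bytes))        // List(1, 2, 3)

      // splitWords breaks a string at non-alphabetic characters (digits included).
      println(Utils.splitWords("word-count, spark 0.8"))  // ArrayBuffer(word, count, spark)
    }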
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + }.isDefined + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** Create a temporary directory inside the given parent directory */ + def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { + var attempts = 0 + val maxAttempts = 10 + var dir: File = null + while (dir == null) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException("Failed to create a temp directory (under " + root + ") after " + + maxAttempts + " attempts!") + } + try { + dir = new File(root, "spark-" + UUID.randomUUID.toString) + if (dir.exists() || !dir.mkdirs()) { + dir = null + } + } catch { case e: IOException => ; } + } + + registerShutdownDeleteDir(dir) + + // Add a shutdown hook to delete the temp dir when the JVM exits + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { + override def run() { + // Attempt to delete if some patch which is parent of this is not already registered. + if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) + } + }) + dir + } + + /** Copy all data from an InputStream to an OutputStream */ + def copyStream(in: InputStream, + out: OutputStream, + closeStreams: Boolean = false) + { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + } + } + if (closeStreams) { + in.close() + out.close() + } + } + + /** + * Download a file requested by the executor. Supports fetching the file in a variety of ways, + * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * + * Throws SparkException if the target file already exists and has different contents than + * the requested file. + */ + def fetchFile(url: String, targetDir: File) { + val filename = url.split("/").last + val tempDir = getLocalDir + val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) + val targetFile = new File(targetDir, filename) + val uri = new URI(url) + uri.getScheme match { + case "http" | "https" | "ftp" => + logInfo("Fetching " + url + " to " + tempFile) + val in = new URL(url).openStream() + val out = new FileOutputStream(tempFile) + Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + tempFile.delete() + throw new SparkException( + "File " + targetFile + " exists and does not match contents of" + " " + url) + } else { + Files.move(tempFile, targetFile) + } + case "file" | null => + // In the case of a local file, copy the local file to the target directory. + // Note the difference between uri vs url. + val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) + if (targetFile.exists) { + // If the target file already exists, warn the user if + if (!Files.equal(sourceFile, targetFile)) { + throw new SparkException( + "File " + targetFile + " exists and does not match contents of" + " " + url) + } else { + // Do nothing if the file contents are the same, i.e. this file has been copied + // previously. + logInfo(sourceFile.getAbsolutePath + " has been previously copied to " + + targetFile.getAbsolutePath) + } + } else { + // The file does not exist in the target directory. Copy it there. 
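createTempDir and copyStream above are the building blocks fetchFile leans on. Used directly, again from inside the spark namespace, they look roughly like this:

    package org.apache.spark.util

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

    object TempDirAndStreamSketch extends App {
      // Creates spark-<uuid> under java.io.tmpdir and registers it for deletion at JVM exit.
      val dir = Utils.createTempDir()
      println(dir.getAbsolutePath)

      // copyStream pumps one stream into the other and, here, closes both when done.
      val in  = new ByteArrayInputStream("hello".getBytes("UTF-8"))
      val out = new ByteArrayOutputStream()
      Utils.copyStream(in, out, closeStreams = true)
      println(out.toString("UTF-8"))                      // hello
    }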
+ logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + Files.copy(sourceFile, targetFile) + } + case _ => + // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others + val env = SparkEnv.get + val uri = new URI(url) + val conf = env.hadoop.newConfiguration() + val fs = FileSystem.get(uri, conf) + val in = fs.open(new Path(uri)) + val out = new FileOutputStream(tempFile) + Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) + } else { + Files.move(tempFile, targetFile) + } + } + // Decompress the file if it's a .tar or .tar.gz + if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { + logInfo("Untarring " + filename) + Utils.execute(Seq("tar", "-xzf", filename), targetDir) + } else if (filename.endsWith(".tar")) { + logInfo("Untarring " + filename) + Utils.execute(Seq("tar", "-xf", filename), targetDir) + } + // Make the file executable - That's necessary for scripts + FileUtil.chmod(targetFile.getAbsolutePath, "a+x") + } + + /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + } + + /** + * Shuffle the elements of a collection into a random order, returning the + * result in a new collection. Unlike scala.util.Random.shuffle, this method + * uses a local random number generator, avoiding inter-thread contention. + */ + def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { + randomizeInPlace(seq.toArray) + } + + /** + * Shuffle the elements of an array into a random order, modifying the + * original array. Returns the original array. + */ + def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { + for (i <- (arr.length - 1) to 1 by -1) { + val j = rand.nextInt(i) + val tmp = arr(j) + arr(j) = arr(i) + arr(i) = tmp + } + arr + } + + /** + * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). + * Note, this is typically not used from within core spark. + */ + lazy val localIpAddress: String = findLocalIpAddress() + lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) + + private def findLocalIpAddress(): String = { + val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") + if (defaultIpOverride != null) { + defaultIpOverride + } else { + val address = InetAddress.getLocalHost + if (address.isLoopbackAddress) { + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + for (ni <- NetworkInterface.getNetworkInterfaces) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && + !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { + // We've found an address that looks reasonable! 
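A note on randomizeInPlace above: rand.nextInt(i) never returns i itself, which makes the loop Sattolo's algorithm (it generates only cyclic permutations) rather than a uniform shuffle. For comparison, a uniform Fisher-Yates version draws from 0..i inclusive; the single change is nextInt(i + 1):

    import java.util.Random

    object ShuffleSketch {
      // Uniform Fisher-Yates shuffle, for comparison with randomizeInPlace.
      def fisherYates[T](arr: Array[T], rand: Random = new Random): Array[T] = {
        for (i <- (arr.length - 1) to 1 by -1) {
          val j = rand.nextInt(i + 1)   // j may equal i, so an element can stay in place
          val tmp = arr(j)
          arr(j) = arr(i)
          arr(i) = tmp
        }
        arr
      }
    }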
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + + " instead (on interface " + ni.getName + ")") + logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") + return addr.getHostAddress + } + } + logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + + " a loopback address: " + address.getHostAddress + ", but we couldn't find any" + + " external IP address!") + logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") + } + address.getHostAddress + } + } + + private var customHostname: Option[String] = None + + /** + * Allow setting a custom host name because when we run on Mesos we need to use the same + * hostname it reports to the master. + */ + def setCustomHostname(hostname: String) { + // DEBUG code + Utils.checkHost(hostname) + customHostname = Some(hostname) + } + + /** + * Get the local machine's hostname. + */ + def localHostName(): String = { + customHostname.getOrElse(localIpAddressHostname) + } + + def getAddressHostName(address: String): String = { + InetAddress.getByName(address).getHostName + } + + def localHostPort(): String = { + val retval = System.getProperty("spark.hostPort", null) + if (retval == null) { + logErrorWithStack("spark.hostPort not set but invoking localHostPort") + return localHostName() + } + + retval + } + + def checkHost(host: String, message: String = "") { + assert(host.indexOf(':') == -1, message) + } + + def checkHostPort(hostPort: String, message: String = "") { + assert(hostPort.indexOf(':') != -1, message) + } + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + } + + // Typically, this will be of order of number of nodes in cluster + // If not, we should change it to LRUCache or something. + private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() + + def parseHostPort(hostPort: String): (String, Int) = { + { + // Check cache first. + var cached = hostPortParseResults.get(hostPort) + if (cached != null) return cached + } + + val indx: Int = hostPort.lastIndexOf(':') + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... + // but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + if (-1 == indx) { + val retval = (hostPort, 0) + hostPortParseResults.put(hostPort, retval) + return retval + } + + val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + hostPortParseResults.get(hostPort) + } + + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() + + /** + * Wrapper over newCachedThreadPool. + */ + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + + /** + * Return the string to tell how long has passed in seconds. The passing parameter should be in + * millisecond. + */ + def getUsedTimeMs(startTimeMs: Long): String = { + return " " + (System.currentTimeMillis - startTimeMs) + " ms" + } + + /** + * Wrapper over newFixedThreadPool. 
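The host/port helpers above are simple string utilities backed by a small parse cache. A few concrete values (from inside the spark namespace, since Utils is private[spark]):

    package org.apache.spark.util

    object HostPortSketch extends App {
      println(Utils.parseHostPort("worker7:7077"))   // (worker7,7077)
      println(Utils.parseHostPort("worker7"))        // (worker7,0) -- a missing port parses as 0
      Utils.checkHostPort("worker7:7077")            // passes: the string contains a ':'
    }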
+ */ + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + + /** + * Delete a file or directory and its contents recursively. + */ + def deleteRecursively(file: File) { + if (file.isDirectory) { + for (child <- file.listFiles()) { + deleteRecursively(child) + } + } + if (!file.delete()) { + throw new IOException("Failed to delete: " + file) + } + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM + * environment variable. + */ + def memoryStringToMb(str: String): Int = { + val lower = str.toLowerCase + if (lower.endsWith("k")) { + (lower.substring(0, lower.length-1).toLong / 1024).toInt + } else if (lower.endsWith("m")) { + lower.substring(0, lower.length-1).toInt + } else if (lower.endsWith("g")) { + lower.substring(0, lower.length-1).toInt * 1024 + } else if (lower.endsWith("t")) { + lower.substring(0, lower.length-1).toInt * 1024 * 1024 + } else {// no suffix, so it's just a number in bytes + (lower.toLong / 1024 / 1024).toInt + } + } + + /** + * Convert a quantity in bytes to a human-readable string such as "4.0 MB". + */ + def bytesToString(size: Long): String = { + val TB = 1L << 40 + val GB = 1L << 30 + val MB = 1L << 20 + val KB = 1L << 10 + + val (value, unit) = { + if (size >= 2*TB) { + (size.asInstanceOf[Double] / TB, "TB") + } else if (size >= 2*GB) { + (size.asInstanceOf[Double] / GB, "GB") + } else if (size >= 2*MB) { + (size.asInstanceOf[Double] / MB, "MB") + } else if (size >= 2*KB) { + (size.asInstanceOf[Double] / KB, "KB") + } else { + (size.asInstanceOf[Double], "B") + } + } + "%.1f %s".formatLocal(Locale.US, value, unit) + } + + /** + * Returns a human-readable string representing a duration such as "35ms" + */ + def msDurationToString(ms: Long): String = { + val second = 1000 + val minute = 60 * second + val hour = 60 * minute + + ms match { + case t if t < second => + "%d ms".format(t) + case t if t < minute => + "%.1f s".format(t.toFloat / second) + case t if t < hour => + "%.1f m".format(t.toFloat / minute) + case t => + "%.2f h".format(t.toFloat / hour) + } + } + + /** + * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". + */ + def megabytesToString(megabytes: Long): String = { + bytesToString(megabytes * 1024L * 1024L) + } + + /** + * Execute a command in the given working directory, throwing an exception if it completes + * with an exit code other than 0. + */ + def execute(command: Seq[String], workingDir: File) { + val process = new ProcessBuilder(command: _*) + .directory(workingDir) + .redirectErrorStream(true) + .start() + new Thread("read stdout for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getInputStream).getLines) { + System.err.println(line) + } + } + }.start() + val exitCode = process.waitFor() + if (exitCode != 0) { + throw new SparkException("Process " + command + " exited with code " + exitCode) + } + } + + /** + * Execute a command in the current working directory, throwing an exception if it completes + * with an exit code other than 0. + */ + def execute(command: Seq[String]) { + execute(command, new File(".")) + } + + /** + * Execute a command and get its output, throwing an exception if it yields a code other than 0. 
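The conversion helpers above map between human-readable strings, megabytes and bytes; a few concrete inputs and the outputs the code produces:

    package org.apache.spark.util

    object SizeStringSketch extends App {
      println(Utils.memoryStringToMb("512m"))                // 512
      println(Utils.memoryStringToMb("2g"))                  // 2048
      println(Utils.bytesToString(5L * 1024 * 1024 * 1024))  // 5.0 GB
      println(Utils.megabytesToString(300))                  // 300.0 MB
      println(Utils.msDurationToString(1500))                // 1.5 s
    }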
+ */ + def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty): String = { + val builder = new ProcessBuilder(command: _*) + .directory(workingDir) + val environment = builder.environment() + for ((key, value) <- extraEnvironment) { + environment.put(key, value) + } + val process = builder.start() + new Thread("read stderr for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + val output = new StringBuffer + val stdoutThread = new Thread("read stdout for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getInputStream).getLines) { + output.append(line) + } + } + } + stdoutThread.start() + val exitCode = process.waitFor() + stdoutThread.join() // Wait for it to finish reading output + if (exitCode != 0) { + throw new SparkException("Process " + command + " exited with code " + exitCode) + } + output.toString + } + + /** + * A regular expression to match classes of the "core" Spark API that we want to skip when + * finding the call site of a method. + */ + private val SPARK_CLASS_REGEX = """^spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r + + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) + + /** + * When called inside a class in the spark package, returns the name of the user code class + * (outside the spark package) that called into Spark, as well as which Spark method they called. + * This is used, for example, to tell users where in their code each RDD got created. + */ + def getCallSiteInfo: CallSiteInfo = { + val trace = Thread.currentThread.getStackTrace().filter( el => + (!el.getMethodName.contains("getStackTrace"))) + + // Keep crawling up the stack trace until we find the first function not inside of the spark + // package. We track the last (shallowest) contiguous Spark method. This might be an RDD + // transformation, a SparkContext function (such as parallelize), or anything else that leads + // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. + var lastSparkMethod = "" + var firstUserFile = "" + var firstUserLine = 0 + var finished = false + var firstUserClass = "" + + for (el <- trace) { + if (!finished) { + if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) { + lastSparkMethod = if (el.getMethodName == "") { + // Spark method is a constructor; get its class name + el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) + } else { + el.getMethodName + } + } + else { + firstUserLine = el.getLineNumber + firstUserFile = el.getFileName + firstUserClass = el.getClassName + finished = true + } + } + } + new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) + } + + def formatSparkCallSite = { + val callSiteInfo = getCallSiteInfo + "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, + callSiteInfo.firstUserLine) + } + + /** Return a string containing part of a file from byte 'start' to 'end'. 
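executeAndGetOutput above captures the child's stdout (concatenating lines without separators) while echoing stderr, and throws SparkException on a non-zero exit code. A minimal use on a Unix-like machine, again from inside the spark namespace:

    package org.apache.spark.util

    object ExecSketch extends App {
      // Runs in the current directory with no extra environment variables.
      val out = Utils.executeAndGetOutput(Seq("echo", "hello"))
      println(out)                                           // hello
    }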
*/ + def offsetBytes(path: String, start: Long, end: Long): String = { + val file = new File(path) + val length = file.length() + val effectiveEnd = math.min(length, end) + val effectiveStart = math.max(0, start) + val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) + val stream = new FileInputStream(file) + + stream.skip(effectiveStart) + stream.read(buff) + stream.close() + Source.fromBytes(buff).mkString + } + + /** + * Clone an object using a Spark serializer. + */ + def clone[T](value: T, serializer: SerializerInstance): T = { + serializer.deserialize[T](serializer.serialize(value)) + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + return false + } + + def isSpace(c: Char): Boolean = { + " \t\r\n".indexOf(c) != -1 + } + + /** + * Split a string of potentially quoted arguments from the command line the way that a shell + * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', + * then it would be parsed as three arguments: 'a', 'b c' and 'd'. + */ + def splitCommandString(s: String): Seq[String] = { + val buf = new ArrayBuffer[String] + var inWord = false + var inSingleQuote = false + var inDoubleQuote = false + var curWord = new StringBuilder + def endWord() { + buf += curWord.toString + curWord.clear() + } + var i = 0 + while (i < s.length) { + var nextChar = s.charAt(i) + if (inDoubleQuote) { + if (nextChar == '"') { + inDoubleQuote = false + } else if (nextChar == '\\') { + if (i < s.length - 1) { + // Append the next character directly, because only " and \ may be escaped in + // double quotes after the shell's own expansion + curWord.append(s.charAt(i + 1)) + i += 1 + } + } else { + curWord.append(nextChar) + } + } else if (inSingleQuote) { + if (nextChar == '\'') { + inSingleQuote = false + } else { + curWord.append(nextChar) + } + // Backslashes are not treated specially in single quotes + } else if (nextChar == '"') { + inWord = true + inDoubleQuote = true + } else if (nextChar == '\'') { + inWord = true + inSingleQuote = true + } else if (!isSpace(nextChar)) { + curWord.append(nextChar) + inWord = true + } else if (inWord && isSpace(nextChar)) { + endWord() + inWord = false + } + i += 1 + } + if (inWord || inDoubleQuote || inSingleQuote) { + endWord() + } + return buf + } + + /* Calculates 'x' modulo 'mod', takes to consideration sign of x, + * i.e. if 'x' is negative, than 'x' % 'mod' is negative too + * so function return (x % mod) + mod in that case. 
+ */ + def nonNegativeMod(x: Int, mod: Int): Int = { + val rawMod = x % mod + rawMod + (if (rawMod < 0) mod else 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 23b14f4245..d9103aebb7 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ import storage.StorageLevel +import org.apache.spark.util.Utils class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { initLogging() diff --git a/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala deleted file mode 100644 index 8494899b98..0000000000 --- a/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.NotSerializableException - -import org.scalatest.FunSuite -import org.apache.spark.LocalSparkContext._ -import SparkContext._ - -class ClosureCleanerSuite extends FunSuite { - test("closures inside an object") { - assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures inside a class") { - val obj = new TestClass - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures inside a class with no default constructor") { - val obj = new TestClassWithoutDefaultConstructor(5) - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures that don't use fields of the outer class") { - val obj = new TestClassWithoutFieldAccess - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("nested closures inside an object") { - assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 - } - - test("nested closures inside a class") { - val obj = new TestClassWithNesting(1) - assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 - } -} - -// A non-serializable class we create in closures to make sure that we aren't -// keeping references to unneeded variables from our outer closures. 
-class NonSerializable {} - -object TestObject { - def run(): Int = { - var nonSer = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + x).reduce(_ + _) - } - } -} - -class TestClass extends Serializable { - var x = 5 - - def getX = x - - def run(): Int = { - var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + getX).reduce(_ + _) - } - } -} - -class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x - - def run(): Int = { - var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + getX).reduce(_ + _) - } - } -} - -// This class is not serializable, but we aren't using any of its fields in our -// closures, so they won't have a $outer pointing to it and should still work. -class TestClassWithoutFieldAccess { - var nonSer = new NonSerializable - - def run(): Int = { - var nonSer2 = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + x).reduce(_ + _) - } - } -} - - -object TestObjectWithNesting { - def run(): Int = { - var nonSer = new NonSerializable - var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - var y = 1 - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + y).reduce(_ + _) - } - answer - } - } -} - -class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y - - def run(): Int = { - var nonSer = new NonSerializable - var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + getY).reduce(_ + _) - } - answer - } - } -} diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index b08aad1a6f..01a72d8401 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index ee89a7a387..af448fcb37 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.scalatest.FunSuite import SparkContext._ +import org.apache.spark.util.NonSerializable // Common state shared by FailureSuite-launched tasks. 
We use a global object // for this because any local variables used in the task closures will rightfully diff --git a/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala deleted file mode 100644 index d7b23c93fe..0000000000 --- a/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.mutable - -import org.scalatest.FunSuite -import com.esotericsoftware.kryo._ - -import KryoTest._ - -class KryoSerializerSuite extends FunSuite with SharedSparkContext { - test("basic types") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check(1) - check(1L) - check(1.0f) - check(1.0) - check(1.toByte) - check(1.toShort) - check("") - check("hello") - check(Integer.MAX_VALUE) - check(Integer.MIN_VALUE) - check(java.lang.Long.MAX_VALUE) - check(java.lang.Long.MIN_VALUE) - check[String](null) - check(Array(1, 2, 3)) - check(Array(1L, 2L, 3L)) - check(Array(1.0, 2.0, 3.0)) - check(Array(1.0f, 2.9f, 3.9f)) - check(Array("aaa", "bbb", "ccc")) - check(Array("aaa", "bbb", null)) - check(Array(true, false, true)) - check(Array('a', 'b', 'c')) - check(Array[Int]()) - check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) - } - - test("pairs") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check((1, 1)) - check((1, 1L)) - check((1L, 1)) - check((1L, 1L)) - check((1.0, 1)) - check((1, 1.0)) - check((1.0, 1.0)) - check((1.0, 1L)) - check((1L, 1.0)) - check((1.0, 1L)) - check(("x", 1)) - check(("x", 1.0)) - check(("x", 1L)) - check((1, "x")) - check((1.0, "x")) - check((1L, "x")) - check(("x", "x")) - } - - test("Scala data structures") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check(List[Int]()) - check(List[Int](1, 2, 3)) - check(List[String]()) - check(List[String]("x", "y", "z")) - check(None) - check(Some(1)) - check(Some("hi")) - check(mutable.ArrayBuffer(1, 2, 3)) - check(mutable.ArrayBuffer("1", "2", "3")) - check(mutable.Map()) - check(mutable.Map(1 -> "one", 2 -> "two")) - check(mutable.Map("one" -> 1, "two" -> 2)) - check(mutable.HashMap(1 -> "one", 2 -> "two")) - check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) - } - - test("custom registrator") { - import KryoTest._ - System.setProperty("spark.kryo.registrator", 
classOf[MyRegistrator].getName) - - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - - check(CaseClass(17, "hello")) - - val c1 = new ClassWithNoArgConstructor - c1.x = 32 - check(c1) - - val c2 = new ClassWithoutNoArgConstructor(47) - check(c2) - - val hashMap = new java.util.HashMap[String, String] - hashMap.put("foo", "bar") - check(hashMap) - - System.clearProperty("spark.kryo.registrator") - } - - test("kryo with collect") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) - assert(control === result.toSeq) - } - - test("kryo with parallelize") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() - assert (control === result.toSeq) - } - - test("kryo with parallelize for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3) - } - - test("kryo with parallelize for primitive arrays") { - assert (sc.parallelize( Array(1, 2, 3) ).count === 3) - } - - test("kryo with collect for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) - } - - test("kryo with reduce") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(control.sum === result) - } - - // TODO: this still doesn't work - ignore("kryo with fold") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(10 + control.sum === result) - } - - override def beforeAll() { - System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - super.beforeAll() - } - - override def afterAll() { - super.afterAll() - System.clearProperty("spark.kryo.registrator") - System.clearProperty("spark.serializer") - } -} - -object KryoTest { - case class CaseClass(i: Int, s: String) {} - - class ClassWithNoArgConstructor { - var x: Int = 0 - override def equals(other: Any) = other match { - case c: ClassWithNoArgConstructor => x == c.x - case _ => false - } - } - - class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { - case c: ClassWithoutNoArgConstructor => x == c.x - case _ => false - } - } - - class MyRegistrator extends KryoRegistrator { - override def registerClasses(k: Kryo) { - k.register(classOf[CaseClass]) - k.register(classOf[ClassWithNoArgConstructor]) - k.register(classOf[ClassWithoutNoArgConstructor]) - k.register(classOf[java.util.HashMap[_, _]]) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala deleted file mode 100644 index f79752b34e..0000000000 --- a/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala +++ /dev/null @@ -1,299 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet - -import org.scalatest.FunSuite - -import com.google.common.io.Files -import org.apache.spark.SparkContext._ - - -class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { - test("groupByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with duplicates") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with negative key hash codes") { - val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesForMinus1 = groups.find(_._1 == -1).get._2 - assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with many output partitions") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey(10).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("reduceByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with collectAsMap") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() - assert(sums.size === 2) - assert(sums(1) === 7) - assert(sums(2) === 1) - } - - test("reduceByKey with many output partitons") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with partitioner") { - val p = new Partitioner() { - def numPartitions = 2 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) - assert(sums.collect().toSet === Set((1, 4), (0, 1))) - assert(sums.partitioner === Some(p)) - // count the dependencies to make sure there is 
only 1 ShuffledRDD - val deps = new HashSet[RDD[_]]() - def visit(r: RDD[_]) { - for (dep <- r.dependencies) { - deps += dep.rdd - visit(dep.rdd) - } - } - visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection - } - - test("join") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("join all-to-all") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) - val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 6) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (1, 'y')), - (1, (2, 'x')), - (1, (2, 'y')), - (1, (3, 'x')), - (1, (3, 'y')) - )) - } - - test("leftOuterJoin") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.leftOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (1, Some('x'))), - (1, (2, Some('x'))), - (2, (1, Some('y'))), - (2, (1, Some('z'))), - (3, (1, None)) - )) - } - - test("rightOuterJoin") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.rightOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (Some(1), 'x')), - (1, (Some(2), 'x')), - (2, (Some(1), 'y')), - (2, (Some(1), 'z')), - (4, (None, 'w')) - )) - } - - test("join with no matches") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 0) - } - - test("join with many output partitions") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2, 10).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("groupWith") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.groupWith(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), - (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), - (3, (ArrayBuffer(1), ArrayBuffer())), - (4, (ArrayBuffer(), ArrayBuffer('w'))) - )) - } - - test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - } - - test("keys and values") { - val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) - assert(rdd.keys.collect().toList === List(1, 2)) - assert(rdd.values.collect().toList === List("a", "b")) - } - - test("default partitioner uses partition size") { - // specify 2000 partitions - val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) - // do a map, which loses the partitioner - val b = a.map(a => (a, (a * 
2).toString)) - // then a group by, and see we didn't revert to 2 partitions - val c = b.groupByKey() - assert(c.partitions.size === 2000) - } - - test("default partitioner uses largest partitioner") { - val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) - val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) - val c = a.join(b) - assert(c.partitions.size === 2000) - } - - test("subtract") { - val a = sc.parallelize(Array(1, 2, 3), 2) - val b = sc.parallelize(Array(2, 3, 4), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set(1)) - assert(c.partitions.size === a.partitions.size) - } - - test("subtract with narrow dependency") { - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set((1, "a"), (3, "c"))) - // Ideally we could keep the original partitioner... - assert(c.partitioner === None) - } - - test("subtractByKey") { - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) - val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitions.size === a.partitions.size) - } - - test("subtractByKey with narrow dependency") { - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitioner.get === p) - } - - test("foldByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("foldByKey with mutable result type") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() - // Fold the values using in-place mutation - val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() - assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) - // Check that the mutable objects in the original RDD were not changed - assert(bufs.collect().toSet === Set( - (1, ArrayBuffer(1)), - (1, ArrayBuffer(2)), - (1, ArrayBuffer(3)), - (1, ArrayBuffer(1)), - (2, ArrayBuffer(1)))) - } -} diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala index adbe805916..5a18dd13ff 100644 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala @@ -2,7 +2,7 @@ package org.apache.spark import org.scalatest.FunSuite import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.PartitionPruningRDD +import org.apache.spark.rdd.{RDD, PartitionPruningRDD} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { diff --git 
a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7669cf6fb1..7d938917f2 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark -import org.scalatest.FunSuite +import scala.math.abs import scala.collection.mutable.ArrayBuffer -import SparkContext._ + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ import org.apache.spark.util.StatCounter -import scala.math.abs +import org.apache.spark.rdd.RDD class PartitioningSuite extends FunSuite with SharedSparkContext { diff --git a/core/src/test/scala/org/apache/spark/RDDSuite.scala b/core/src/test/scala/org/apache/spark/RDDSuite.scala deleted file mode 100644 index 342ba8adb2..0000000000 --- a/core/src/test/scala/org/apache/spark/RDDSuite.scala +++ /dev/null @@ -1,389 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.mutable.HashMap -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Span, Millis} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ -import scala.collection.parallel.mutable - -class RDDSuite extends FunSuite with SharedSparkContext { - - test("basic operations") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(nums.collect().toList === List(1, 2, 3, 4)) - val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct().count() === 4) - assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? 
- assert(dups.distinct.collect === dups.distinct().collect) - assert(dups.distinct(2).collect === dups.distinct().collect) - assert(nums.reduce(_ + _) === 10) - assert(nums.fold(0)(_ + _) === 10) - assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) - assert(nums.filter(_ > 2).collect().toList === List(3, 4)) - assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) - assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) - assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) - assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) - val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) - assert(partitionSums.collect().toList === List(3, 7)) - - val partitionSumsWithSplit = nums.mapPartitionsWithSplit { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) - } - assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) - - val partitionSumsWithIndex = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) - } - assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) - - intercept[UnsupportedOperationException] { - nums.filter(_ > 5).reduce(_ + _) - } - } - - test("SparkContext.union") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) - assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) - assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - } - - test("aggregate") { - val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) - type StringMap = HashMap[String, Int] - val emptyMap = new StringMap { - override def default(key: String): Int = 0 - } - val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { - map(pair._1) += pair._2 - map - } - val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { - for ((key, value) <- map2) { - map1(key) += value - } - map1 - } - val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) - assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) - } - - test("basic caching") { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - assert(rdd.collect().toList === List(1, 2, 3, 4)) - assert(rdd.collect().toList === List(1, 2, 3, 4)) - assert(rdd.collect().toList === List(1, 2, 3, 4)) - } - - test("caching with failures") { - val onlySplit = new Partition { override def index: Int = 0 } - var shouldFail = true - val rdd = new RDD[Int](sc, Nil) { - override def getPartitions: Array[Partition] = Array(onlySplit) - override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - return Array(1, 2, 3, 4).iterator - } - } - }.cache() - val thrown = intercept[Exception]{ - rdd.collect() - } - assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) - } - - test("empty RDD") { - val empty = new EmptyRDD[Int](sc) - assert(empty.count === 0) - assert(empty.collect().size === 0) - - val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) - } - 
assert(thrown.getMessage.contains("empty")) - - val emptyKv = new EmptyRDD[(Int, Int)](sc) - val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) - assert(rdd.join(emptyKv).collect().size === 0) - assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) - assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) - assert(rdd.cogroup(emptyKv).collect().size === 2) - assert(rdd.union(emptyKv).collect().size === 2) - } - - test("cogrouped RDDs") { - val data = sc.parallelize(1 to 10, 10) - - val coalesced1 = data.coalesce(2) - assert(coalesced1.collect().toList === (1 to 10).toList) - assert(coalesced1.glom().collect().map(_.toList).toList === - List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) - - // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === - List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === - List(5, 6, 7, 8, 9)) - - val coalesced2 = data.coalesce(3) - assert(coalesced2.collect().toList === (1 to 10).toList) - assert(coalesced2.glom().collect().map(_.toList).toList === - List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) - - val coalesced3 = data.coalesce(10) - assert(coalesced3.collect().toList === (1 to 10).toList) - assert(coalesced3.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) - - // If we try to coalesce into more partitions than the original RDD, it should just - // keep the original number of partitions. - val coalesced4 = data.coalesce(20) - assert(coalesced4.collect().toList === (1 to 10).toList) - assert(coalesced4.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) - - // we can optionally shuffle to keep the upstream parallel - val coalesced5 = data.coalesce(1, shuffle = true) - assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != - null) - } - test("cogrouped RDDs with locality") { - val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) - val coal3 = data3.coalesce(3) - val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation) - assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") - - // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) - val coalesced1 = data.coalesce(3) - assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") - - val splits = coalesced1.glom().collect().map(_.toList).toList - assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) - - assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") - - // If we try to coalesce into more partitions than the original RDD, it should just - // keep the original number of partitions. - val coalesced4 = data.coalesce(20) - val listOfLists = coalesced4.glom().collect().map(_.toList).toList - val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } - assert( sortedList === (1 to 9). 
- map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") - } - - test("cogrouped RDDs with locality, large scale (10K partitions)") { - // large scale experiment - import collection.mutable - val rnd = scala.util.Random - val partitions = 10000 - val numMachines = 50 - val machines = mutable.ListBuffer[String]() - (1 to numMachines).foreach(machines += "m"+_) - - val blocks = (1 to partitions).map(i => - { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) - - val data2 = sc.makeRDD(blocks) - val coalesced2 = data2.coalesce(numMachines*2) - - // test that you get over 90% locality in each group - val minLocality = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") - - // test that the groups are load balanced with 100 +/- 20 elements in each - val maxImbalance = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) - .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) - assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) - - val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs - val coalesced3 = data3.coalesce(numMachines*2) - val minLocality2 = coalesced3.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.).toInt + "%") - } - - test("zipped RDDs") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val zipped = nums.zip(nums.map(_ + 1.0)) - assert(zipped.glom().map(_.toList).collect().toList === - List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) - - intercept[IllegalArgumentException] { - nums.zip(sc.parallelize(1 to 4, 1)).collect() - } - } - - test("partition pruning") { - val data = sc.parallelize(1 to 10, 10) - // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) - assert(prunedRdd.partitions.size === 1) - val prunedData = prunedRdd.collect() - assert(prunedData.size === 1) - assert(prunedData(0) === 10) - } - - test("mapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - val randoms = ones.mapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextDouble * t}.collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(2) === prn42_3) - assert(randoms(5) === prn43_3) - } - - test("flatMapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - val randoms = ones.flatMapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => - val random = prng.nextDouble() - Seq(random * t, random * t * 10)}. 
- collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(5) === prn42_3 * 10) - assert(randoms(11) === prn43_3 * 10) - } - - test("filterWith") { - import java.util.Random - val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) - val sample = ints.filterWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextInt(3) == 0}. - collect() - val checkSample = { - val prng42 = new Random(42) - val prng43 = new Random(43) - Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) - else 0 == prng43.nextInt(3)} - } - assert(sample.size === checkSample.size) - for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) - } - - test("top with predefined ordering") { - val nums = Array.range(1, 100000) - val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) - val topK = ints.top(5) - assert(topK.size === 5) - assert(topK === nums.reverse.take(5)) - } - - test("top with custom ordering") { - val words = Vector("a", "b", "c", "d") - implicit val ord = implicitly[Ordering[String]].reverse - val rdd = sc.makeRDD(words, 2) - val topK = rdd.top(2) - assert(topK.size === 2) - assert(topK.sorted === Array("b", "a")) - } - - test("takeOrdered with predefined ordering") { - val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val rdd = sc.makeRDD(nums, 2) - val sortedLowerK = rdd.takeOrdered(5) - assert(sortedLowerK.size === 5) - assert(sortedLowerK === Array(1, 2, 3, 4, 5)) - } - - test("takeOrdered with custom ordering") { - val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - implicit val ord = implicitly[Ordering[Int]].reverse - val rdd = sc.makeRDD(nums, 2) - val sortedTopK = rdd.takeOrdered(5) - assert(sortedTopK.size === 5) - assert(sortedTopK === Array(10, 9, 8, 7, 6)) - assert(sortedTopK === nums.sorted(ord).take(5)) - } - - test("takeSample") { - val data = sc.parallelize(1 to 100, 2) - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 20, seed) - assert(sample.size === 20) // Got exactly 20 elements - assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) - assert(sample.size === 100) // Got only 100 elements - assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 20, seed) - assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements - // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements - // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - } - } - - test("runJob on an invalid 
partition") { - intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 357175e89e..db717865db 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -22,8 +22,9 @@ import org.scalatest.matchers.ShouldMatchers import org.apache.spark.SparkContext._ import org.apache.spark.ShuffleSuite.NonJavaSerializableClass -import org.apache.spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD} +import org.apache.spark.rdd.{RDD, SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD} import org.apache.spark.util.MutablePair +import org.apache.spark.serializer.KryoSerializer class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala deleted file mode 100644 index 214ac74898..0000000000 --- a/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfterAll -import org.scalatest.PrivateMethodTester - -class DummyClass1 {} - -class DummyClass2 { - val x: Int = 0 -} - -class DummyClass3 { - val x: Int = 0 - val y: Double = 0.0 -} - -class DummyClass4(val d: DummyClass3) { - val x: Int = 0 -} - -object DummyString { - def apply(str: String) : DummyString = new DummyString(str.toArray) -} -class DummyString(val arr: Array[Char]) { - override val hashCode: Int = 0 - // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f - @transient val hash32: Int = 0 -} - -class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { - - var oldArch: String = _ - var oldOops: String = _ - - override def beforeAll() { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - } - - override def afterAll() { - resetOrClear("os.arch", oldArch) - resetOrClear("spark.test.useCompressedOops", oldOops) - } - - test("simple classes") { - assert(SizeEstimator.estimate(new DummyClass1) === 16) - assert(SizeEstimator.estimate(new DummyClass2) === 16) - assert(SizeEstimator.estimate(new DummyClass3) === 24) - assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) - assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) - } - - // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors - // (Sun vs IBM). Use a DummyString class to make tests deterministic. - test("strings") { - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - } - - test("primitive arrays") { - assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) - assert(SizeEstimator.estimate(new Array[Char](10)) === 40) - assert(SizeEstimator.estimate(new Array[Short](10)) === 40) - assert(SizeEstimator.estimate(new Array[Int](10)) === 56) - assert(SizeEstimator.estimate(new Array[Long](10)) === 96) - assert(SizeEstimator.estimate(new Array[Float](10)) === 56) - assert(SizeEstimator.estimate(new Array[Double](10)) === 96) - assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) - assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) - } - - test("object arrays") { - // Arrays containing nulls should just have one pointer per element - assert(SizeEstimator.estimate(new Array[String](10)) === 56) - assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) - - // For object arrays with non-null elements, each object should take one pointer plus - // however many bytes that class takes. (Note that Array.fill calls the code in its - // second parameter separately for each object, so we get distinct objects.) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) - assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) - - // Past size 100, our samples 100 elements, but we should still get the right size. 
- assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) - - // If an array contains the *same* element many times, we should only count it once. - val d1 = new DummyClass1 - assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object - assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object - - // Same thing with huge array containing the same element many times. Note that this won't - // return exactly 4032 because it can't tell that *all* the elements will equal the first - // one it samples, but it should be close to that. - - // TODO: If we sample 100 elements, this should always be 4176 ? - val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) - assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") - assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4100") - } - - test("32-bit arch") { - val arch = System.setProperty("os.arch", "x86") - - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - - resetOrClear("os.arch", arch) - } - - // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors - // (Sun vs IBM). Use a DummyString class to make tests deterministic. - test("64-bit arch with no compressed oops") { - val arch = System.setProperty("os.arch", "amd64") - val oops = System.setProperty("spark.test.useCompressedOops", "false") - - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - assert(SizeEstimator.estimate(DummyString("")) === 56) - assert(SizeEstimator.estimate(DummyString("a")) === 64) - assert(SizeEstimator.estimate(DummyString("ab")) === 64) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) - - resetOrClear("os.arch", arch) - resetOrClear("spark.test.useCompressedOops", oops) - } - - def resetOrClear(prop: String, oldValue: String) { - if (oldValue != null) { - System.setProperty(prop, oldValue) - } else { - System.clearProperty(prop) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/SortingSuite.scala b/core/src/test/scala/org/apache/spark/SortingSuite.scala deleted file mode 100644 index f4fa9511dd..0000000000 --- a/core/src/test/scala/org/apache/spark/SortingSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.matchers.ShouldMatchers -import SparkContext._ - -class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { - - test("sortByKey") { - val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) - } - - test("large array") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey() - assert(sorted.partitions.size === 2) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("large array with one split") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey(true, 1) - assert(sorted.partitions.size === 1) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("large array with many partitions") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey(true, 20) - assert(sorted.partitions.size === 20) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("sort descending") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("sort descending with one split") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 1) - assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("sort descending with many partitions") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("more partitions than elements") { - val rand = new scala.util.Random() - val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 30) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - } - - test("empty RDD") { - val pairArr = new Array[(Int, Int)](0) - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - } - - test("partition balancing") { - val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey() - assert(sorted.collect() === pairArr.sortBy(_._1)) - val partitions = sorted.collectPartitions() - logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) - partitions(0).length should be > 180 - partitions(1).length should be > 180 - partitions(2).length should be > 180 - partitions(3).length should be > 180 - partitions(0).last should be < partitions(1).head - partitions(1).last should be < partitions(2).head - partitions(2).last should be < partitions(3).head - } - - test("partition balancing for descending sort") { - val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey(false) - assert(sorted.collect() === pairArr.sortBy(_._1).reverse) - val 
partitions = sorted.collectPartitions() - logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) - partitions(0).length should be > 180 - partitions(1).length should be > 180 - partitions(2).length should be > 180 - partitions(3).length should be > 180 - partitions(0).last should be > partitions(1).head - partitions(1).last should be > partitions(2).head - partitions(2).last should be > partitions(3).head - } -} - diff --git a/core/src/test/scala/org/apache/spark/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/UtilsSuite.scala deleted file mode 100644 index 3a908720a8..0000000000 --- a/core/src/test/scala/org/apache/spark/UtilsSuite.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import com.google.common.base.Charsets -import com.google.common.io.Files -import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File} -import org.scalatest.FunSuite -import org.apache.commons.io.FileUtils -import scala.util.Random - -class UtilsSuite extends FunSuite { - - test("bytesToString") { - assert(Utils.bytesToString(10) === "10.0 B") - assert(Utils.bytesToString(1500) === "1500.0 B") - assert(Utils.bytesToString(2000000) === "1953.1 KB") - assert(Utils.bytesToString(2097152) === "2.0 MB") - assert(Utils.bytesToString(2306867) === "2.2 MB") - assert(Utils.bytesToString(5368709120L) === "5.0 GB") - assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") - } - - test("copyStream") { - //input array initialization - val bytes = Array.ofDim[Byte](9000) - Random.nextBytes(bytes) - - val os = new ByteArrayOutputStream() - Utils.copyStream(new ByteArrayInputStream(bytes), os) - - assert(os.toByteArray.toList.equals(bytes.toList)) - } - - test("memoryStringToMb") { - assert(Utils.memoryStringToMb("1") === 0) - assert(Utils.memoryStringToMb("1048575") === 0) - assert(Utils.memoryStringToMb("3145728") === 3) - - assert(Utils.memoryStringToMb("1024k") === 1) - assert(Utils.memoryStringToMb("5000k") === 4) - assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) - - assert(Utils.memoryStringToMb("1024m") === 1024) - assert(Utils.memoryStringToMb("5000m") === 5000) - assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) - - assert(Utils.memoryStringToMb("2g") === 2048) - assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) - - assert(Utils.memoryStringToMb("2t") === 2097152) - assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) - } - - test("splitCommandString") { - assert(Utils.splitCommandString("") === Seq()) - assert(Utils.splitCommandString("a") === Seq("a")) - assert(Utils.splitCommandString("aaa") === Seq("aaa")) - 
assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) - assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) - assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) - assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) - assert(Utils.splitCommandString("'b c'") === Seq("b c")) - assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) - assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) - assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) - assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) - assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) - assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) - assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) - assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) - assert(Utils.splitCommandString("'a'b") === Seq("ab")) - assert(Utils.splitCommandString("'a''b'") === Seq("ab")) - assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) - assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) - assert(Utils.splitCommandString("''") === Seq("")) - assert(Utils.splitCommandString("\"\"") === Seq("")) - } - - test("string formatting of time durations") { - val second = 1000 - val minute = second * 60 - val hour = minute * 60 - def str = Utils.msDurationToString(_) - - assert(str(123) === "123 ms") - assert(str(second) === "1.0 s") - assert(str(second + 462) === "1.5 s") - assert(str(hour) === "1.00 h") - assert(str(minute) === "1.0 m") - assert(str(minute + 4 * second + 34) === "1.1 m") - assert(str(10 * hour + minute + 4 * second) === "10.02 h") - assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h") - } - - test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() - val f1Path = tmpDir2 + "/f1" - val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) - f1.close() - - // Read first few bytes - assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") - - // Read some middle bytes - assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") - - // Read last few bytes - assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") - - // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") - - // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") - - // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") - - FileUtils.deleteDirectory(tmpDir2) - } -} - diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala new file mode 100644 index 0000000000..31f97fc139 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet + +import org.scalatest.FunSuite + +import com.google.common.io.Files +import org.apache.spark.SparkContext._ +import org.apache.spark.{Partitioner, SharedSparkContext} + + +class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("groupByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with duplicates") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with negative key hash codes") { + val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesForMinus1 = groups.find(_._1 == -1).get._2 + assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with many output partitions") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey(10).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("reduceByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with collectAsMap") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collectAsMap() + assert(sums.size === 2) + assert(sums(1) === 7) + assert(sums(2) === 1) + } + + test("reduceByKey with many output partitons") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_, 10).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with partitioner") { + val p = new Partitioner() { + def numPartitions = 2 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) + assert(sums.collect().toSet === Set((1, 4), (0, 1))) + assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- 
r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection + } + + test("join") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("join all-to-all") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) + val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (1, 'y')), + (1, (2, 'x')), + (1, (2, 'y')), + (1, (3, 'x')), + (1, (3, 'y')) + )) + } + + test("leftOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.leftOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (1, Some('x'))), + (1, (2, Some('x'))), + (2, (1, Some('y'))), + (2, (1, Some('z'))), + (3, (1, None)) + )) + } + + test("rightOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.rightOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (Some(1), 'x')), + (1, (Some(2), 'x')), + (2, (Some(1), 'y')), + (2, (Some(1), 'z')), + (4, (None, 'w')) + )) + } + + test("join with no matches") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 0) + } + + test("join with many output partitions") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2, 10).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("groupWith") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.groupWith(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), + (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), + (3, (ArrayBuffer(1), ArrayBuffer())), + (4, (ArrayBuffer(), ArrayBuffer('w'))) + )) + } + + test("zero-partition RDD") { + val emptyDir = Files.createTempDir() + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.size == 0) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } + + test("keys and values") { + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } + + test("default partitioner uses partition size") { + // specify 2000 partitions + val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + // do a map, which loses the partitioner + val b = a.map(a => (a, (a * 2).toString)) + // then a group by, and see we didn't revert to 2 partitions + val c = 
b.groupByKey() + assert(c.partitions.size === 2000) + } + + test("default partitioner uses largest partitioner") { + val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) + val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) + val c = a.join(b) + assert(c.partitions.size === 2000) + } + + test("subtract") { + val a = sc.parallelize(Array(1, 2, 3), 2) + val b = sc.parallelize(Array(2, 3, 4), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set(1)) + assert(c.partitions.size === a.partitions.size) + } + + test("subtract with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set((1, "a"), (3, "c"))) + // Ideally we could keep the original partitioner... + assert(c.partitioner === None) + } + + test("subtractByKey") { + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) + val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitions.size === a.partitions.size) + } + + test("subtractByKey with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitioner.get === p) + } + + test("foldByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala new file mode 100644 index 0000000000..adc971050e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable.HashMap +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import scala.collection.parallel.mutable +import org.apache.spark._ +import org.apache.spark.rdd.CoalescedRDDPartition + +class RDDSuite extends FunSuite with SharedSparkContext { + + test("basic operations") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.collect().toList === List(1, 2, 3, 4)) + val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) + assert(dups.distinct().count() === 4) + assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? + assert(dups.distinct.collect === dups.distinct().collect) + assert(dups.distinct(2).collect === dups.distinct().collect) + assert(nums.reduce(_ + _) === 10) + assert(nums.fold(0)(_ + _) === 10) + assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) + assert(nums.filter(_ > 2).collect().toList === List(3, 4)) + assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) + assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) + assert(partitionSums.collect().toList === List(3, 7)) + + val partitionSumsWithSplit = nums.mapPartitionsWithSplit { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + + val partitionSumsWithIndex = nums.mapPartitionsWithIndex { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) + + intercept[UnsupportedOperationException] { + nums.filter(_ > 5).reduce(_ + _) + } + } + + test("SparkContext.union") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) + assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) + assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + } + + test("aggregate") { + val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) + type StringMap = HashMap[String, Int] + val emptyMap = new StringMap { + override def default(key: String): Int = 0 + } + val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { + map(pair._1) += pair._2 + map + } + val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { + for ((key, value) <- map2) { + map1(key) += value + } + map1 + } + val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) + 
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) + } + + test("basic caching") { + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + assert(rdd.collect().toList === List(1, 2, 3, 4)) + assert(rdd.collect().toList === List(1, 2, 3, 4)) + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + + test("caching with failures") { + val onlySplit = new Partition { override def index: Int = 0 } + var shouldFail = true + val rdd = new RDD[Int](sc, Nil) { + override def getPartitions: Array[Partition] = Array(onlySplit) + override val getDependencies = List[Dependency[_]]() + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + if (shouldFail) { + throw new Exception("injected failure") + } else { + return Array(1, 2, 3, 4).iterator + } + } + }.cache() + val thrown = intercept[Exception]{ + rdd.collect() + } + assert(thrown.getMessage.contains("injected failure")) + shouldFail = false + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + + test("empty RDD") { + val empty = new EmptyRDD[Int](sc) + assert(empty.count === 0) + assert(empty.collect().size === 0) + + val thrown = intercept[UnsupportedOperationException]{ + empty.reduce(_+_) + } + assert(thrown.getMessage.contains("empty")) + + val emptyKv = new EmptyRDD[(Int, Int)](sc) + val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) + assert(rdd.join(emptyKv).collect().size === 0) + assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) + assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.cogroup(emptyKv).collect().size === 2) + assert(rdd.union(emptyKv).collect().size === 2) + } + + test("cogrouped RDDs") { + val data = sc.parallelize(1 to 10, 10) + + val coalesced1 = data.coalesce(2) + assert(coalesced1.collect().toList === (1 to 10).toList) + assert(coalesced1.glom().collect().map(_.toList).toList === + List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) + + // Check that the narrow dependency is also specified correctly + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) + + val coalesced2 = data.coalesce(3) + assert(coalesced2.collect().toList === (1 to 10).toList) + assert(coalesced2.glom().collect().map(_.toList).toList === + List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) + + val coalesced3 = data.coalesce(10) + assert(coalesced3.collect().toList === (1 to 10).toList) + assert(coalesced3.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. 
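+ // (Without a shuffle, coalesce can only merge existing partitions, so asking for 20 from a 10-partition RDD keeps all 10.)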
+ val coalesced4 = data.coalesce(20) + assert(coalesced4.collect().toList === (1 to 10).toList) + assert(coalesced4.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) + + // we can optionally shuffle to keep the upstream parallel + val coalesced5 = data.coalesce(1, shuffle = true) + assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != + null) + } + test("cogrouped RDDs with locality") { + val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) + val coal3 = data3.coalesce(3) + val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation) + assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") + + // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val coalesced1 = data.coalesce(3) + assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") + + val splits = coalesced1.glom().collect().map(_.toList).toList + assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) + + assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. + val coalesced4 = data.coalesce(20) + val listOfLists = coalesced4.glom().collect().map(_.toList).toList + val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } + assert( sortedList === (1 to 9). + map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") + } + + test("cogrouped RDDs with locality, large scale (10K partitions)") { + // large scale experiment + import collection.mutable + val rnd = scala.util.Random + val partitions = 10000 + val numMachines = 50 + val machines = mutable.ListBuffer[String]() + (1 to numMachines).foreach(machines += "m"+_) + + val blocks = (1 to partitions).map(i => + { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) + + val data2 = sc.makeRDD(blocks) + val coalesced2 = data2.coalesce(numMachines*2) + + // test that you get over 90% locality in each group + val minLocality = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") + + // test that the groups are load balanced with 100 +/- 20 elements in each + val maxImbalance = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) + .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) + assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) + + val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs + val coalesced3 = data3.coalesce(numMachines*2) + val minLocality2 = coalesced3.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + + (minLocality2*100.).toInt + "%") + } + + test("zipped RDDs") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val zipped = nums.zip(nums.map(_ + 1.0)) + assert(zipped.glom().map(_.toList).collect().toList === + List(List((1, 2.0), 
(2, 3.0)), List((3, 4.0), (4, 5.0)))) + + intercept[IllegalArgumentException] { + nums.zip(sc.parallelize(1 to 4, 1)).collect() + } + } + + test("partition pruning") { + val data = sc.parallelize(1 to 10, 10) + // Note that split number starts from 0, so > 8 means only 10th partition left. + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) + assert(prunedRdd.partitions.size === 1) + val prunedData = prunedRdd.collect() + assert(prunedData.size === 1) + assert(prunedData(0) === 10) + } + + test("mapWith") { + import java.util.Random + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.mapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextDouble * t}.collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(2) === prn42_3) + assert(randoms(5) === prn43_3) + } + + test("flatMapWith") { + import java.util.Random + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.flatMapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => + val random = prng.nextDouble() + Seq(random * t, random * t * 10)}. + collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(5) === prn42_3 * 10) + assert(randoms(11) === prn43_3 * 10) + } + + test("filterWith") { + import java.util.Random + val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) + val sample = ints.filterWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextInt(3) == 0}. 
+ collect() + val checkSample = { + val prng42 = new Random(42) + val prng43 = new Random(43) + Array(1, 2, 3, 4, 5, 6).filter{i => + if (i < 4) 0 == prng42.nextInt(3) + else 0 == prng43.nextInt(3)} + } + assert(sample.size === checkSample.size) + for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) + } + + test("top with predefined ordering") { + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK === nums.reverse.take(5)) + } + + test("top with custom ordering") { + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } + + test("takeOrdered with predefined ordering") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val rdd = sc.makeRDD(nums, 2) + val sortedLowerK = rdd.takeOrdered(5) + assert(sortedLowerK.size === 5) + assert(sortedLowerK === Array(1, 2, 3, 4, 5)) + } + + test("takeOrdered with custom ordering") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + implicit val ord = implicitly[Ordering[Int]].reverse + val rdd = sc.makeRDD(nums, 2) + val sortedTopK = rdd.takeOrdered(5) + assert(sortedTopK.size === 5) + assert(sortedTopK === Array(10, 9, 8, 7, 6)) + assert(sortedTopK === nums.sorted(ord).take(5)) + } + + test("takeSample") { + val data = sc.parallelize(1 to 100, 2) + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=false, 20, seed) + assert(sample.size === 20) // Got exactly 20 elements + assert(sample.toSet.size === 20) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=false, 200, seed) + assert(sample.size === 100) // Got only 100 elements + assert(sample.toSet.size === 100) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 20, seed) + assert(sample.size === 20) // Got exactly 20 elements + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 100, seed) + assert(sample.size === 100) // Got exactly 100 elements + // Chance of getting all distinct elements is astronomically low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 200, seed) + assert(sample.size === 200) // Got exactly 200 elements + // Chance of getting all distinct elements is still quite low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + } + } + + test("runJob on an invalid partition") { + intercept[IllegalArgumentException] { + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala new file mode 100644 index 0000000000..2f7bd370fc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.{Logging, SharedSparkContext} +import org.apache.spark.SparkContext._ + +class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { + + test("sortByKey") { + val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + } + + test("large array") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey() + assert(sorted.partitions.size === 2) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("large array with one split") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey(true, 1) + assert(sorted.partitions.size === 1) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("large array with many partitions") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey(true, 20) + assert(sorted.partitions.size === 20) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("sort descending") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("sort descending with one split") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 1) + assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("sort descending with many partitions") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("more partitions than elements") { + val rand = new scala.util.Random() + val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 30) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("empty RDD") { + val pairArr = new Array[(Int, Int)](0) + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("partition balancing") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val 
sorted = sc.parallelize(pairArr, 4).sortByKey() + assert(sorted.collect() === pairArr.sortBy(_._1)) + val partitions = sorted.collectPartitions() + logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 180 + partitions(1).length should be > 180 + partitions(2).length should be > 180 + partitions(3).length should be > 180 + partitions(0).last should be < partitions(1).head + partitions(1).last should be < partitions(2).head + partitions(2).last should be < partitions(3).head + } + + test("partition balancing for descending sort") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey(false) + assert(sorted.collect() === pairArr.sortBy(_._1).reverse) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 180 + partitions(1).length should be > 180 + partitions(2).length should be > 180 + partitions(3).length should be > 180 + partitions(0).last should be > partitions(1).head + partitions(1).last should be > partitions(2).head + partitions(2).last should be > partitions(3).head + } +} + diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 94df282b28..94f66c94c6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.LocalSparkContext import org.apache.spark.MapOutputTracker -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.Partition import org.apache.spark.TaskContext diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index f5b3e97222..cece60dda7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -19,11 +19,15 @@ package org.apache.spark.scheduler import java.util.Properties import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable + import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import scala.collection.mutable + import org.apache.spark._ import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 0347cc02d7..e31a116a75 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark.TaskContext -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.Partition import org.apache.spark.LocalSparkContext diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index a4f63baf3d..ff70a2cdf0 100644 
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.scheduler._ import org.apache.spark.executor.TaskMetrics import java.nio.ByteBuffer -import org.apache.spark.util.FakeClock +import org.apache.spark.util.{Utils, FakeClock} /** * A mock ClusterScheduler implementation that just remembers information about tasks started and diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala new file mode 100644 index 0000000000..0164dda0ba --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.Kryo + +import org.scalatest.FunSuite +import org.apache.spark.SharedSparkContext +import org.apache.spark.serializer.KryoTest._ + +class KryoSerializerSuite extends FunSuite with SharedSparkContext { + test("basic types") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check(1) + check(1L) + check(1.0f) + check(1.0) + check(1.toByte) + check(1.toShort) + check("") + check("hello") + check(Integer.MAX_VALUE) + check(Integer.MIN_VALUE) + check(java.lang.Long.MAX_VALUE) + check(java.lang.Long.MIN_VALUE) + check[String](null) + check(Array(1, 2, 3)) + check(Array(1L, 2L, 3L)) + check(Array(1.0, 2.0, 3.0)) + check(Array(1.0f, 2.9f, 3.9f)) + check(Array("aaa", "bbb", "ccc")) + check(Array("aaa", "bbb", null)) + check(Array(true, false, true)) + check(Array('a', 'b', 'c')) + check(Array[Int]()) + check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) + } + + test("pairs") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 1)) + check((1, 1L)) + check((1L, 1)) + check((1L, 1L)) + check((1.0, 1)) + check((1, 1.0)) + check((1.0, 1.0)) + check((1.0, 1L)) + check((1L, 1.0)) + check((1.0, 1L)) + check(("x", 1)) + check(("x", 1.0)) + check(("x", 1L)) + check((1, "x")) + check((1.0, "x")) + check((1L, "x")) + check(("x", "x")) + } + + test("Scala data structures") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(mutable.ArrayBuffer(1, 2, 3)) + 
check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + } + + test("custom registrator") { + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + check(CaseClass(17, "hello")) + + val c1 = new ClassWithNoArgConstructor + c1.x = 32 + check(c1) + + val c2 = new ClassWithoutNoArgConstructor(47) + check(c2) + + val hashMap = new java.util.HashMap[String, String] + hashMap.put("foo", "bar") + check(hashMap) + + System.clearProperty("spark.kryo.registrator") + } + + test("kryo with collect") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + assert(control === result.toSeq) + } + + test("kryo with parallelize") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() + assert (control === result.toSeq) + } + + test("kryo with parallelize for specialized tuples") { + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3) + } + + test("kryo with parallelize for primitive arrays") { + assert (sc.parallelize( Array(1, 2, 3) ).count === 3) + } + + test("kryo with collect for specialized tuples") { + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) + } + + test("kryo with reduce") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(control.sum === result) + } + + // TODO: this still doesn't work + ignore("kryo with fold") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(10 + control.sum === result) + } + + override def beforeAll() { + System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + super.beforeAll() + } + + override def afterAll() { + super.afterAll() + System.clearProperty("spark.kryo.registrator") + System.clearProperty("spark.serializer") + } +} + +object KryoTest { + case class CaseClass(i: Int, s: String) {} + + class ClassWithNoArgConstructor { + var x: Int = 0 + override def equals(other: Any) = other match { + case c: ClassWithNoArgConstructor => x == c.x + case _ => false + } + } + + class ClassWithoutNoArgConstructor(val x: Int) { + override def equals(other: Any) = other match { + case c: ClassWithoutNoArgConstructor => x == c.x + case _ => false + } + } + + class MyRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + k.register(classOf[CaseClass]) + k.register(classOf[ClassWithNoArgConstructor]) + k.register(classOf[ClassWithoutNoArgConstructor]) + k.register(classOf[java.util.HashMap[_, _]]) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 88ba10f2f2..038a9acb85 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -29,12 +29,8 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.JavaSerializer -import org.apache.spark.KryoSerializer -import org.apache.spark.SizeEstimator -import org.apache.spark.Utils -import org.apache.spark.util.AkkaUtils -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala new file mode 100644 index 0000000000..0ed366fb70 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.NotSerializableException + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.LocalSparkContext._ + +class ClosureCleanerSuite extends FunSuite { + test("closures inside an object") { + assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures inside a class") { + val obj = new TestClass + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures inside a class with no default constructor") { + val obj = new TestClassWithoutDefaultConstructor(5) + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures that don't use fields of the outer class") { + val obj = new TestClassWithoutFieldAccess + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("nested closures inside an object") { + assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 + } + + test("nested closures inside a class") { + val obj = new TestClassWithNesting(1) + assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 + } +} + +// A non-serializable class we create in closures to make sure that we aren't +// keeping references to unneeded variables from our outer closures. 
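+// If the cleaner failed to null out those references, the jobs below would fail with a NotSerializableException when the task closures are serialized.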
+class NonSerializable {} + +object TestObject { + def run(): Int = { + var nonSer = new NonSerializable + var x = 5 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } + } +} + +class TestClass extends Serializable { + var x = 5 + + def getX = x + + def run(): Int = { + var nonSer = new NonSerializable + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } + } +} + +class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { + def getX = x + + def run(): Int = { + var nonSer = new NonSerializable + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } + } +} + +// This class is not serializable, but we aren't using any of its fields in our +// closures, so they won't have a $outer pointing to it and should still work. +class TestClassWithoutFieldAccess { + var nonSer = new NonSerializable + + def run(): Int = { + var nonSer2 = new NonSerializable + var x = 5 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } + } +} + + +object TestObjectWithNesting { + def run(): Int = { + var nonSer = new NonSerializable + var answer = 0 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + var y = 1 + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + y).reduce(_ + _) + } + answer + } + } +} + +class TestClassWithNesting(val y: Int) extends Serializable { + def getY = y + + def run(): Int = { + var nonSer = new NonSerializable + var answer = 0 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + getY).reduce(_ + _) + } + answer + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala new file mode 100644 index 0000000000..4e40dcbdee --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfterAll +import org.scalatest.PrivateMethodTester + +class DummyClass1 {} + +class DummyClass2 { + val x: Int = 0 +} + +class DummyClass3 { + val x: Int = 0 + val y: Double = 0.0 +} + +class DummyClass4(val d: DummyClass3) { + val x: Int = 0 +} + +object DummyString { + def apply(str: String) : DummyString = new DummyString(str.toArray) +} +class DummyString(val arr: Array[Char]) { + override val hashCode: Int = 0 + // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f + @transient val hash32: Int = 0 +} + +class SizeEstimatorSuite + extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { + + var oldArch: String = _ + var oldOops: String = _ + + override def beforeAll() { + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + oldArch = System.setProperty("os.arch", "amd64") + oldOops = System.setProperty("spark.test.useCompressedOops", "true") + } + + override def afterAll() { + resetOrClear("os.arch", oldArch) + resetOrClear("spark.test.useCompressedOops", oldOops) + } + + test("simple classes") { + assert(SizeEstimator.estimate(new DummyClass1) === 16) + assert(SizeEstimator.estimate(new DummyClass2) === 16) + assert(SizeEstimator.estimate(new DummyClass3) === 24) + assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) + assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) + } + + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. + test("strings") { + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + } + + test("primitive arrays") { + assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) + assert(SizeEstimator.estimate(new Array[Char](10)) === 40) + assert(SizeEstimator.estimate(new Array[Short](10)) === 40) + assert(SizeEstimator.estimate(new Array[Int](10)) === 56) + assert(SizeEstimator.estimate(new Array[Long](10)) === 96) + assert(SizeEstimator.estimate(new Array[Float](10)) === 56) + assert(SizeEstimator.estimate(new Array[Double](10)) === 96) + assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) + assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) + } + + test("object arrays") { + // Arrays containing nulls should just have one pointer per element + assert(SizeEstimator.estimate(new Array[String](10)) === 56) + assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) + + // For object arrays with non-null elements, each object should take one pointer plus + // however many bytes that class takes. (Note that Array.fill calls the code in its + // second parameter separately for each object, so we get distinct objects.) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) + assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) + + // Past size 100, our samples 100 elements, but we should still get the right size. 
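+ // (SizeEstimator samples a subset of a large array's elements and extrapolates; every element here is a DummyClass3 of identical size, so the extrapolated total is still exact.)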
+ assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) + + // If an array contains the *same* element many times, we should only count it once. + val d1 = new DummyClass1 + assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object + assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object + + // Same thing with huge array containing the same element many times. Note that this won't + // return exactly 4032 because it can't tell that *all* the elements will equal the first + // one it samples, but it should be close to that. + + // TODO: If we sample 100 elements, this should always be 4176 ? + val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) + assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") + assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4100") + } + + test("32-bit arch") { + val arch = System.setProperty("os.arch", "x86") + + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + + resetOrClear("os.arch", arch) + } + + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. + test("64-bit arch with no compressed oops") { + val arch = System.setProperty("os.arch", "amd64") + val oops = System.setProperty("spark.test.useCompressedOops", "false") + + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + + assert(SizeEstimator.estimate(DummyString("")) === 56) + assert(SizeEstimator.estimate(DummyString("a")) === 64) + assert(SizeEstimator.estimate(DummyString("ab")) === 64) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) + + resetOrClear("os.arch", arch) + resetOrClear("spark.test.useCompressedOops", oops) + } + + def resetOrClear(prop: String, oldValue: String) { + if (oldValue != null) { + System.setProperty(prop, oldValue) + } else { + System.clearProperty(prop) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala new file mode 100644 index 0000000000..e2859caf58 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import com.google.common.base.Charsets +import com.google.common.io.Files +import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File} +import org.scalatest.FunSuite +import org.apache.commons.io.FileUtils +import scala.util.Random + +class UtilsSuite extends FunSuite { + + test("bytesToString") { + assert(Utils.bytesToString(10) === "10.0 B") + assert(Utils.bytesToString(1500) === "1500.0 B") + assert(Utils.bytesToString(2000000) === "1953.1 KB") + assert(Utils.bytesToString(2097152) === "2.0 MB") + assert(Utils.bytesToString(2306867) === "2.2 MB") + assert(Utils.bytesToString(5368709120L) === "5.0 GB") + assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") + } + + test("copyStream") { + //input array initialization + val bytes = Array.ofDim[Byte](9000) + Random.nextBytes(bytes) + + val os = new ByteArrayOutputStream() + Utils.copyStream(new ByteArrayInputStream(bytes), os) + + assert(os.toByteArray.toList.equals(bytes.toList)) + } + + test("memoryStringToMb") { + assert(Utils.memoryStringToMb("1") === 0) + assert(Utils.memoryStringToMb("1048575") === 0) + assert(Utils.memoryStringToMb("3145728") === 3) + + assert(Utils.memoryStringToMb("1024k") === 1) + assert(Utils.memoryStringToMb("5000k") === 4) + assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) + + assert(Utils.memoryStringToMb("1024m") === 1024) + assert(Utils.memoryStringToMb("5000m") === 5000) + assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) + + assert(Utils.memoryStringToMb("2g") === 2048) + assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) + + assert(Utils.memoryStringToMb("2t") === 2097152) + assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) + } + + test("splitCommandString") { + assert(Utils.splitCommandString("") === Seq()) + assert(Utils.splitCommandString("a") === Seq("a")) + assert(Utils.splitCommandString("aaa") === Seq("aaa")) + assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) + assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) + assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) + assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("'b c'") === Seq("b c")) + assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) + assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) + assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) + assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) + assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) + assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) + assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) + assert(Utils.splitCommandString("'a'b") === Seq("ab")) + assert(Utils.splitCommandString("'a''b'") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) + assert(Utils.splitCommandString("''") === Seq("")) + assert(Utils.splitCommandString("\"\"") === Seq("")) + } + + test("string formatting of time durations") { + val second = 1000 + val minute = second * 60 + val hour = minute * 60 + def str = Utils.msDurationToString(_) + + assert(str(123) === "123 ms") + assert(str(second) === "1.0 s") + assert(str(second + 462) === "1.5 s") + 
assert(str(hour) === "1.00 h") + assert(str(minute) === "1.0 m") + assert(str(minute + 4 * second + 34) === "1.1 m") + assert(str(10 * hour + minute + 4 * second) === "10.02 h") + assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h") + } + + test("reading offset bytes of a file") { + val tmpDir2 = Files.createTempDir() + val f1Path = tmpDir2 + "/f1" + val f1 = new FileOutputStream(f1Path) + f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) + f1.close() + + // Read first few bytes + assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") + + // Read some middle bytes + assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") + + // Read last few bytes + assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") + + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") + + // Read some nonexistent bytes at the end + assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") + + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") + + FileUtils.deleteDirectory(tmpDir2) + } +} + diff --git a/docs/configuration.md b/docs/configuration.md index 55df18b6fb..58e9434bdc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -36,13 +36,13 @@ there are at least five properties that you will commonly want to control: spark.serializer - org.apache.spark.JavaSerializer + org.apache.spark.serializer.
JavaSerializer Class to use for serializing objects that will be sent over the network or need to be cached in serialized form. The default of Java serialization works with any Serializable Java object but is - quite slow, so we recommend using org.apache.spark.KryoSerializer + quite slow, so we recommend using org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization when speed is necessary. Can be any subclass of - org.apache.spark.Serializer. + org.apache.spark.Serializer. @@ -51,7 +51,7 @@ there are at least five properties that you will commonly want to control: If you use Kryo serialization, set this class to register your custom classes with Kryo. It should be set to a class that extends - KryoRegistrator. + KryoRegistrator. See the tuning guide for more details. @@ -171,7 +171,7 @@ Apart from these, the following properties are also available, and may be useful spark.closure.serializer - org.apache.spark.JavaSerializer + org.apache.spark.serializer.
JavaSerializer Serializer class to use for closures. Generally Java is fine unless your distributed functions (e.g. map functions) reference large objects in the driver program. diff --git a/docs/quick-start.md b/docs/quick-start.md index 8cbc55bed2..70c3df8095 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -109,7 +109,7 @@ We'll create a very simple Spark job in Scala. So simple, in fact, that it's nam {% highlight scala %} /*** SimpleJob.scala ***/ import org.apache.spark.SparkContext -import SparkContext._ +import org.apache.spark.SparkContext._ object SimpleJob { def main(args: Array[String]) { diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 2cf319a263..f7768e55fc 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -37,7 +37,7 @@ Finally, you need to import some Spark classes and implicit conversions into you {% highlight scala %} import org.apache.spark.SparkContext -import SparkContext._ +import org.apache.spark.SparkContext._ {% endhighlight %} # Initializing Spark diff --git a/docs/tuning.md b/docs/tuning.md index 3563d110c9..28d88a2659 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -38,17 +38,17 @@ in your operations) and performance. It provides two serialization libraries: `Serializable` types and requires you to *register* the classes you'll use in the program in advance for best performance. -You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer")` +You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")` *before* creating your SparkContext. The only reason it is not the default is because of the custom registration requirement, but we recommend trying it in any network-intensive application. Finally, to register your classes with Kryo, create a public class that extends -[`org.apache.spark.KryoRegistrator`](api/core/index.html#org.apache.spark.KryoRegistrator) and set the +[`org.apache.spark.serializer.KryoRegistrator`](api/core/index.html#org.apache.spark.serializer.KryoRegistrator) and set the `spark.kryo.registrator` system property to point to it, as follows: {% highlight scala %} import com.esotericsoftware.kryo.Kryo -import org.apache.spark.KryoRegistrator +import org.apache.spark.serializer.KryoRegistrator class MyRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo) { @@ -58,7 +58,7 @@ class MyRegistrator extends KryoRegistrator { } // Make sure to set these properties *before* creating a SparkContext! -System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") +System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator") val sc = new SparkContext(...) {% endhighlight %} @@ -217,7 +217,7 @@ enough. Spark automatically sets the number of "map" tasks to run on each file a (though you can control it through optional parameters to `SparkContext.textFile`, etc), and for distributed "reduce" operations, such as `groupByKey` and `reduceByKey`, it uses the largest parent RDD's number of partitions. 
You can pass the level of parallelism as a second argument -(see the [`spark.PairRDDFunctions`](api/core/index.html#org.apache.spark.PairRDDFunctions) documentation), +(see the [`spark.PairRDDFunctions`](api/core/index.html#org.apache.spark.rdd.PairRDDFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. In general, we recommend 2-3 tasks per CPU core in your cluster. diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index b190e83c4d..cfafbaf23e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.bagel import org.apache.spark._ import org.apache.spark.SparkContext._ +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.bagel._ import org.apache.spark.bagel.Bagel._ diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala index b1f606e48e..72b5c7b88e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -37,7 +37,7 @@ object WikipediaPageRank { System.exit(-1) } - System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) val inputFile = args(0) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index 3bfa48eaf3..ddf6855325 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -17,19 +17,16 @@ package org.apache.spark.examples.bagel -import org.apache.spark._ -import serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ -import org.apache.spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import scala.xml.{XML, NodeSeq} -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer +import org.apache.spark._ +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD object WikipediaPageRankStandalone { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala index 822da8c9b5..fad512eeba 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.examples -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD 
import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala index 2e3d9ccf00..0b45c30d20 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala @@ -28,7 +28,7 @@ import org.apache.spark.streaming.util.RawTextHelper * lines have the word 'the' in them. This is useful for benchmarking purposes. This * will only work with spark.streaming.util.RawTextSender running on all worker nodes * and with Spark using Kryo serialization (set Java property "spark.serializer" to - * "org.apache.spark.KryoSerializer"). + * "org.apache.spark.serializer.KryoSerializer"). * Usage: RawNetworkGrep * is the Spark master URL * is the number rawNetworkStreams, which should be same as number diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 4f4a7f5296..60cb44ce89 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -1,6 +1,6 @@ package org.apache.spark.mllib.classification -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD trait ClassificationModel extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 91bb50c829..50aede9c07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.classification import scala.math.round -import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index c92c7cc3f3..3511e24bce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.classification import scala.math.signum -import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 2c3db099fa..edbf77dbcc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,8 +20,9 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.apache.spark.{SparkContext, RDD} 
+import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index d1fe5d138d..cfc81c985a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index b62c9b3340..b77364e08d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.jblas.DoubleMatrix @@ -29,8 +30,9 @@ import scala.collection.mutable.ArrayBuffer * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer { - +class GradientDescent(var gradient: Gradient, var updater: Updater) + extends Optimizer with Logging +{ private var stepSize: Double = 1.0 private var numIterations: Int = 100 private var regParam: Double = 0.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 50059d385d..94d30b56f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD trait Optimizer { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 218217acfe..be002d02bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -21,9 +21,10 @@ import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.util.Random import scala.util.Sorting -import org.apache.spark.{HashPartitioner, Partitioner, SparkContext, RDD} +import org.apache.spark.{HashPartitioner, Partitioner, SparkContext} import org.apache.spark.storage.StorageLevel -import org.apache.spark.KryoRegistrator +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.SparkContext._ import com.esotericsoftware.kryo.Kryo @@ -432,7 +433,7 @@ object ALS { val (master, ratingsFile, rank, iters, outputDir) = (args(0), args(1), args(2).toInt, args(3).toInt, args(4)) val blocks = if (args.length == 6) args(5).toInt else -1 - System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.serializer", 
"org.apache.spark.serializer.KryoSerializer") System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName) System.setProperty("spark.kryo.referenceTracking", "false") System.setProperty("spark.kryoserializer.buffer.mb", "8") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ae9fe48aec..af43d89c70 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.recommendation -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.jblas._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 06015110ac..f98b0b536d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.regression -import org.apache.spark.{Logging, RDD, SparkException} +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.jblas.DoubleMatrix diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index df3beb1959..d959695325 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.regression -import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 71f968471c..ae95ea24fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.regression -import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 8dd325efc0..423afc32d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD trait RegressionModel extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 228ab9e4e8..b29508d2b9 
100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.regression -import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index 7fd4623071..8b55bce7c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.util -import org.apache.spark.{RDD, Logging} +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index 6500d47183..9109189dff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD /** * Generate test data for KMeans. This class first chooses k cluster centers diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 4c49d484b4..bc5045fb05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -22,7 +22,8 @@ import scala.util.Random import org.jblas.DoubleMatrix -import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index f553298fc5..52c4a71d62 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 7eb69ae81c..5aec867257 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -21,7 +21,8 @@ import scala.util.Random import org.jblas.DoubleMatrix -import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import 
org.apache.spark.mllib.util.MLUtils /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 0aeafbe23c..d91b74c3ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.util -import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.jblas.DoubleMatrix diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index d3f191b05b..6e9f667635 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -2,9 +2,10 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.{RDD, SparkContext} - import org.jblas.DoubleMatrix + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint /** diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 906e9221a1..8fbf296509 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -114,9 +114,9 @@ class SparkContext(object): self.addPyFile(path) # Create a temporary directory inside spark.local.dir: - local_dir = self._jvm.org.apache.spark.Utils.getLocalDir() + local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir() self._temp_dir = \ - self._jvm.org.apache.spark.Utils.createTempDir(local_dir).getAbsolutePath() + self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() @property def defaultParallelism(self): diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 7e244e48a2..e6e35c9b5d 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -28,7 +28,7 @@ import scala.reflect.NameTransformer import SparkIMain._ import org.apache.spark.HttpServer -import org.apache.spark.Utils +import org.apache.spark.util.Utils import org.apache.spark.SparkEnv /** An interpreter for Scala code. 
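The hunks above all follow one pattern: classes that used to sit directly under org.apache.spark now resolve from sub-packages (RDD from org.apache.spark.rdd, KryoSerializer and KryoRegistrator from org.apache.spark.serializer, Utils from org.apache.spark.util). As a rough sketch of what this means for a downstream Scala application (not part of the patch; the object name, the registered class, and the job body are assumptions for illustration), the updated imports and Kryo setup look roughly like this:

{% highlight scala %}
// Hypothetical downstream application updated for the new package layout.
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._                // implicit conversions, e.g. pair RDD functions
import org.apache.spark.rdd.RDD                       // was org.apache.spark.RDD
import org.apache.spark.serializer.KryoRegistrator   // was org.apache.spark.KryoRegistrator

import com.esotericsoftware.kryo.Kryo

// Hypothetical registrator; the registered class is only an example.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[Array[Double]])
  }
}

object MyApp {
  def main(args: Array[String]) {
    // The serializer is named by its new fully qualified class name,
    // which now includes the ".serializer" package segment.
    System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)

    val sc = new SparkContext("local", "MyApp")
    val data: RDD[Int] = sc.parallelize(1 to 1000)
    // reduceByKey comes from the SparkContext._ implicits imported above.
    val counts = data.map(x => (x % 10, 1)).reduceByKey(_ + _)
    counts.collect().foreach(println)
    sc.stop()
  }
}
{% endhighlight %}

The only user-visible difference from the old layout is the extra package segment in the import lines and in the spark.serializer property value; the SparkContext API calls themselves are unchanged by this patch.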
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 362247cc38..80da6bd30b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -23,7 +23,8 @@ import org.apache.spark.util.MetadataCleaner //import Time._ -import org.apache.spark.{RDD, Logging} +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index 290ad37812..6bf275f5af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import org.apache.spark.Utils +import org.apache.spark.util.Utils case class Duration (private val millis: Long) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala index d8a7381e87..757bc98981 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala @@ -22,8 +22,9 @@ import org.apache.spark.streaming.dstream.{ReducedWindowedDStream, StateDStream} import org.apache.spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream} import org.apache.spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream} -import org.apache.spark.{Manifests, RDD, Partitioner, HashPartitioner} +import org.apache.spark.{Partitioner, HashPartitioner} import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{Manifests, RDD, PairRDDFunctions} import org.apache.spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer @@ -101,8 +102,8 @@ extends Serializable { /** * Combine elements of each key in DStream's RDDs using custom functions. This is similar to the - * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more - * information. + * combineByKey for RDDs. Please refer to combineByKey in + * [[org.apache.spark.rdd.PairRDDFunctions]] for more information. */ def combineByKey[C: ClassManifest]( createCombiner: V => C, @@ -379,7 +380,7 @@ extends Serializable { /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. - * [[org.apache.spark.Paxrtitioner]] is used to control the partitioning of each RDD. + * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. 
Note, that * this function may generate a different a tuple with a different key diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 3852ac2dab..878725c705 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -24,6 +24,7 @@ import akka.zeromq.Subscribe import org.apache.spark.streaming.dstream._ import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.receivers.ActorReceiver import org.apache.spark.streaming.receivers.ReceiverSupervisorStrategy import org.apache.spark.streaming.receivers.ZeroMQReceiver diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index f8c8d8ece1..d1932b6b05 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -21,7 +21,7 @@ import org.apache.spark.streaming.{Duration, Time, DStream} import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.api.java.JavaRDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 2e6fe9a9c4..459695b7ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -26,7 +26,7 @@ import org.apache.spark.streaming._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import java.util -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import JavaDStream._ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index c203dccd17..978fca33ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -25,14 +25,15 @@ import scala.collection.JavaConversions._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.{RDD, Partitioner} +import org.apache.spark.Partitioner import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration import org.apache.spark.api.java.{JavaUtils, JavaRDD, JavaPairRDD} import org.apache.spark.storage.StorageLevel import com.google.common.base.Optional -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.PairRDDFunctions class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( 
implicit val kManifiest: ClassManifest[K], @@ -147,7 +148,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Combine elements of each key in DStream's RDDs using custom function. This is similar to the - * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more + * combineByKey for RDDs. Please refer to combineByKey in [[PairRDDFunctions]] for more * information. */ def combineByKey[C](createCombiner: JFunction[V, C], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index f10beb1db3..54ba3e6025 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,23 +17,26 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming._ -import receivers.{ActorReceiver, ReceiverSupervisorStrategy} -import org.apache.spark.streaming.dstream._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import java.lang.{Long => JLong, Integer => JInt} +import java.io.InputStream +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ + import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import twitter4j.Status import akka.actor.Props import akka.actor.SupervisorStrategy import akka.zeromq.Subscribe -import scala.collection.JavaConversions._ -import java.lang.{Long => JLong, Integer => JInt} -import java.io.InputStream -import java.util.{Map => JMap} import twitter4j.auth.Authorization -import org.apache.spark.RDD + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.receivers.{ActorReceiver, ReceiverSupervisorStrategy} /** * A StreamingContext is the main entry point for Spark Streaming functionality. 
Besides the basic diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala index 4a9d82211f..4eddc755b9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.{RDD, Partitioner} +import org.apache.spark.Partitioner +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.CoGroupedRDD import org.apache.spark.streaming.{Time, DStream, Duration} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index 35cc4cb396..a9a05c9981 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, StreamingContext} /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 1c265ed972..fea0573b77 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index 3166c68760..91ee2c1a36 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD private[streaming] class FilteredDStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 21950ad6ac..ca7d7ca49e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 8377cfe60c..b37966f9a7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD private[streaming] class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala index 3fb443143c..18de772946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.StreamingContext +import java.net.InetSocketAddress +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.nio.ByteBuffer -import org.apache.spark.Utils -import org.apache.spark.storage.StorageLevel +import scala.collection.JavaConversions._ import org.apache.flume.source.avro.AvroSourceProtocol import org.apache.flume.source.avro.AvroFlumeEvent @@ -28,11 +29,9 @@ import org.apache.flume.source.avro.Status import org.apache.avro.ipc.specific.SpecificResponder import org.apache.avro.ipc.NettyServer -import scala.collection.JavaConversions._ - -import java.net.InetSocketAddress -import java.io.{ObjectInput, ObjectOutput, Externalizable} -import java.nio.ByteBuffer +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel private[streaming] class FlumeInputDStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index c1f95650c8..e21bac4602 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, DStream, Job, Time} private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index 1e4c7e7fde..4294b07d91 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD private[streaming] class GlommedDStream[T: ClassManifest](parent: DStream[T]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 1d79d707bb..5329601a6f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import 
org.apache.spark.rdd.RDD private[streaming] class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 312e0c0567..8290df90a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index af688dde5f..b1682afea3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD private[streaming] class MappedDStream[T: ClassManifest, U: ClassManifest] ( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 3d68da36a2..31f9891560 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,22 +17,21 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver} - -import org.apache.spark.{Logging, SparkEnv, RDD} -import org.apache.spark.rdd.BlockRDD -import org.apache.spark.storage.StorageLevel +import java.util.concurrent.ArrayBlockingQueue +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import java.nio.ByteBuffer - import akka.actor.{Props, Actor} import akka.pattern.ask import akka.dispatch.Await import akka.util.duration._ + import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} -import java.util.concurrent.ArrayBlockingQueue +import org.apache.spark.streaming._ +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.rdd.{RDD, BlockRDD} +import org.apache.spark.storage.StorageLevel /** * Abstract class for defining any InputDStream that has to start a receiver on worker diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index b43ecaeebe..7d9f3521b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import scala.collection.mutable.Queue diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 
b6c672f899..b88a4db959 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.StreamingContext._ -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD} import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 3a0bd2acd7..a95e66d761 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.{RDD, Partitioner} +import org.apache.spark.Partitioner +import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{Duration, DStream, Time} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index c1c9f808f0..362a6bf4cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ import org.apache.spark.storage.StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index edba2032b4..60485adef9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, DStream, Time} private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index 97eab97b2f..c696bb70a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.{Duration, DStream, Time} -import org.apache.spark.RDD +import org.apache.spark.rdd.RDD import collection.mutable.ArrayBuffer import org.apache.spark.rdd.UnionRDD diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index dbbea39e81..3c57294269 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.RDD +import 
org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Interval, Time, DStream} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala index 50d72298e4..6977957126 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.util -import org.apache.spark.{Logging, RDD} +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.ForEachDStream import StreamingContext._ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index 249f6a22ae..fc8655a083 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -20,10 +20,11 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import org.apache.spark.util.{RateLimitedOutputStream, IntParam} import java.net.ServerSocket -import org.apache.spark.{Logging, KryoSerializer} +import org.apache.spark.{Logging} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import scala.io.Source import java.io.IOException +import org.apache.spark.serializer.KryoSerializer /** * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 31c2fa0208..37dd9c4cc6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -20,8 +20,6 @@ package org.apache.spark.streaming import org.apache.spark.streaming.dstream.{InputDStream, ForEachDStream} import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.{RDD, Logging} - import collection.mutable.ArrayBuffer import collection.mutable.SynchronizedBuffer @@ -29,6 +27,9 @@ import java.io.{ObjectInputStream, IOException} import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD + /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, * replayable, reliable message queue like Kafka. 
It requires a sequence as input, and diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 50335e5736..f824c472ae 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.api.java._ -import org.apache.spark.rdd.OrderedRDDFunctions +import org.apache.spark.rdd.{RDD, DoubleRDDFunctions, PairRDDFunctions, OrderedRDDFunctions} import org.apache.spark.streaming.{PairDStreamFunctions, DStream, StreamingContext} import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream, JavaStreamingContext} @@ -139,7 +139,7 @@ object JavaAPICompletenessChecker { scalaType match { case ParameterizedType(name, parameters, typebounds) => name match { - case "org.apache.spark.RDD" => + case "org.apache.spark.rdd.RDD" => if (parameters(0).name == classOf[Tuple2[_, _]].getName) { val tupleParams = parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs) @@ -211,23 +211,23 @@ object JavaAPICompletenessChecker { // This list also includes a few methods that are only used by the web UI or other // internal Spark components. val excludedNames = Seq( - "org.apache.spark.RDD.origin", - "org.apache.spark.RDD.elementClassManifest", - "org.apache.spark.RDD.checkpointData", - "org.apache.spark.RDD.partitioner", - "org.apache.spark.RDD.partitions", - "org.apache.spark.RDD.firstParent", - "org.apache.spark.RDD.doCheckpoint", - "org.apache.spark.RDD.markCheckpointed", - "org.apache.spark.RDD.clearDependencies", - "org.apache.spark.RDD.getDependencies", - "org.apache.spark.RDD.getPartitions", - "org.apache.spark.RDD.dependencies", - "org.apache.spark.RDD.getPreferredLocations", - "org.apache.spark.RDD.collectPartitions", - "org.apache.spark.RDD.computeOrReadCheckpoint", - "org.apache.spark.PairRDDFunctions.getKeyClass", - "org.apache.spark.PairRDDFunctions.getValueClass", + "org.apache.spark.rdd.RDD.origin", + "org.apache.spark.rdd.RDD.elementClassManifest", + "org.apache.spark.rdd.RDD.checkpointData", + "org.apache.spark.rdd.RDD.partitioner", + "org.apache.spark.rdd.RDD.partitions", + "org.apache.spark.rdd.RDD.firstParent", + "org.apache.spark.rdd.RDD.doCheckpoint", + "org.apache.spark.rdd.RDD.markCheckpointed", + "org.apache.spark.rdd.RDD.clearDependencies", + "org.apache.spark.rdd.RDD.getDependencies", + "org.apache.spark.rdd.RDD.getPartitions", + "org.apache.spark.rdd.RDD.dependencies", + "org.apache.spark.rdd.RDD.getPreferredLocations", + "org.apache.spark.rdd.RDD.collectPartitions", + "org.apache.spark.rdd.RDD.computeOrReadCheckpoint", + "org.apache.spark.rdd.PairRDDFunctions.getKeyClass", + "org.apache.spark.rdd.PairRDDFunctions.getValueClass", "org.apache.spark.SparkContext.stringToText", "org.apache.spark.SparkContext.makeRDD", "org.apache.spark.SparkContext.runJob", -- cgit v1.2.3
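The streaming hunks above, together with the JavaAPICompletenessChecker rename list at the end, apply the same substitution once more: RDD is imported from org.apache.spark.rdd wherever a DStream implementation or a fully qualified method name refers to it. Mirroring the QueueStream example touched earlier in this patch, a minimal streaming sketch compiled against the relocated classes could look like the following (not part of the patch; the local[2] master, the one-second batch interval, and the object name are assumptions):

{% highlight scala %}
import scala.collection.mutable.Queue

import org.apache.spark.rdd.RDD                       // was org.apache.spark.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._  // pair DStream functions

object QueueStreamSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "QueueStreamSketch", Seconds(1))

    // The queue holds RDDs, whose type now comes from org.apache.spark.rdd.
    val rddQueue = new Queue[RDD[Int]]()
    val inputStream = ssc.queueStream(rddQueue)

    inputStream.map(x => (x % 10, 1)).reduceByKey(_ + _).print()

    ssc.start()
    // Push a few RDDs into the queue, one per batch interval.
    for (i <- 1 to 5) {
      rddQueue += ssc.sparkContext.makeRDD(1 to 100, 2)
      Thread.sleep(1000)
    }
    ssc.stop()
  }
}
{% endhighlight %}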