From d0ecca6075d86bedebf8bc2278085a2cd6cb0a43 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Mon, 20 Feb 2017 09:02:09 -0800
Subject: [SPARK-19646][CORE][STREAMING] binaryRecords replicates records in scala API

## What changes were proposed in this pull request?

Use `BytesWritable.copyBytes`, not `getBytes`, because `getBytes` returns the underlying backing array, which may be reused across reads when the next record does not need a larger buffer, as is the case with the fixed-length binaryRecords APIs. Without the copy, every record in a batch could end up aliasing the same last-read bytes.
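To make the failure mode concrete, here is a minimal illustrative sketch (not part of the patch; the object name is invented, and it assumes only `hadoop-common` on the classpath) showing how `getBytes` aliases Hadoop's reused buffer while `copyBytes()` does not:

```scala
import org.apache.hadoop.io.BytesWritable

object GetBytesVsCopyBytes {
  def main(args: Array[String]): Unit = {
    val writable = new BytesWritable()

    // Simulate a record reader reusing one BytesWritable for two
    // fixed-length records, as FixedLengthBinaryInputFormat does.
    writable.set(Array[Byte](1, 2, 3), 0, 3)
    val aliased = writable.getBytes    // reference into the shared buffer
    val copied  = writable.copyBytes() // independent, right-sized copy

    writable.set(Array[Byte](4, 5, 6), 0, 3) // "read" the next record

    // The alias now shows the second record; the copy is unaffected.
    // (getBytes can also return an array longer than getLength, hence take.)
    println(aliased.take(3).toSeq) // List(4, 5, 6)
    println(copied.toSeq)          // List(1, 2, 3)
  }
}
```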
## How was this patch tested?

Existing tests

Author: Sean Owen

Closes #16974 from srowen/SPARK-19646.
---
 .../apache/spark/streaming/StreamingContext.scala  |  5 ++---
 .../apache/spark/streaming/InputStreamsSuite.scala | 21 +++++++++++----------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 0a4c141e5b..a34f6c73fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -435,13 +435,12 @@ class StreamingContext private[streaming] (
     conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
     val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](
       directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf)
-    val data = br.map { case (k, v) =>
-      val bytes = v.getBytes
+    br.map { case (k, v) =>
+      val bytes = v.copyBytes()
       require(bytes.length == recordLength, "Byte array does not have correct length. " +
         s"${bytes.length} did not equal recordLength: $recordLength")
       bytes
     }
-    data
   }
 
   /**
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 6fb50a4052..b5d36a3651 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -84,7 +84,7 @@
 
       // Verify whether all the elements received are as expected
       // (whether the elements were received one in each interval is not verified)
-      val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray
+      val output = outputQueue.asScala.flatten.toArray
       assert(output.length === expectedOutput.size)
       for (i <- output.indices) {
         assert(output(i) === expectedOutput(i))
@@ -155,14 +155,15 @@
       // not enough to trigger a batch
       clock.advance(batchDuration.milliseconds / 2)
 
-      val input = Seq(1, 2, 3, 4, 5)
-      input.foreach { i =>
+      val numCopies = 3
+      val input = Array[Byte](1, 2, 3, 4, 5)
+      for (i <- 0 until numCopies) {
         Thread.sleep(batchDuration.milliseconds)
         val file = new File(testDir, i.toString)
-        Files.write(Array[Byte](i.toByte), file)
+        Files.write(input.map(b => (b + i).toByte), file)
         assert(file.setLastModified(clock.getTimeMillis()))
         assert(file.lastModified === clock.getTimeMillis())
-        logInfo("Created file " + file)
+        logInfo(s"Created file $file")
         // Advance the clock after creating the file to avoid a race when
         // setting its modification time
         clock.advance(batchDuration.milliseconds)
@@ -170,10 +171,10 @@
           assert(batchCounter.getNumCompletedBatches === i)
         }
       }
-
-      val expectedOutput = input.map(i => i.toByte)
-      val obtainedOutput = outputQueue.asScala.flatten.toList.map(i => i(0).toByte)
-      assert(obtainedOutput.toSeq === expectedOutput)
+      val obtainedOutput = outputQueue.asScala.map(_.flatten).toSeq
+      for (i <- obtainedOutput.indices) {
+        assert(obtainedOutput(i) === input.map(b => (b + i).toByte))
+      }
     }
   } finally {
     if (testDir != null) Utils.deleteRecursively(testDir)
@@ -258,7 +259,7 @@
     val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread)
     MultiThreadTestReceiver.haveAllThreadsFinished = false
     val outputQueue = new ConcurrentLinkedQueue[Seq[Long]]
-    def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x)
+    def output: Iterable[Long] = outputQueue.asScala.flatten
 
     // set up the network stream using the test receiver
     withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
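
As a usage-level illustration (not part of this patch), here is a hedged sketch of the API the fix affects. The directory, record length, and app name are invented for the example; the point is that, after this change, each `Array[Byte]` emitted by `binaryRecordsStream` is an independent copy, so retaining records across batches is safe:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object BinaryRecordsExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("binary-records")
    val ssc = new StreamingContext(conf, Seconds(5))

    // Each file dropped into /tmp/records is split into fixed 8-byte records.
    val records = ssc.binaryRecordsStream("/tmp/records", recordLength = 8)

    // Before SPARK-19646 these arrays could alias one reused Hadoop buffer,
    // making all records in a batch appear identical; now each is a copy.
    records.foreachRDD { rdd => rdd.take(5).foreach(r => println(r.toSeq)) }

    ssc.start()
    ssc.awaitTermination()
  }
}
```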