diff options
author | Shixiong Zhu <shixiong@databricks.com> | 2015-12-04 17:02:04 -0800 |
---|---|---|
committer | Marcelo Vanzin <vanzin@cloudera.com> | 2015-12-04 17:02:04 -0800 |
commit | 3af53e61fd604fe8000e1fdf656d60b79c842d1c (patch) | |
tree | b12b7a4cdb05361813fc81f93bbed924ba961822 /streaming/src | |
parent | f30373f5ee60f9892c28771e34b208e4f1f675a6 (diff) | |
download | spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.tar.gz spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.tar.bz2 spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.zip |
[SPARK-12084][CORE] Fix codes that uses ByteBuffer.array incorrectly
`ByteBuffer` doesn't guarantee all contents in `ByteBuffer.array` are valid. E.g, a ByteBuffer returned by `ByteBuffer.slice`. We should not use the whole content of `ByteBuffer` unless we know that's correct.
This patch fixed all places that use `ByteBuffer.array` incorrectly.
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #10083 from zsxwing/bytebuffer-array.
Diffstat (limited to 'streaming/src')
4 files changed, 19 insertions, 30 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 500dc70c98..4dab64d696 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} @@ -210,9 +211,9 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") writeAheadLog.readAll().asScala.foreach { byteBuffer => - logTrace("Recovering record " + byteBuffer) + logInfo("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( - byteBuffer.array, Thread.currentThread().getContextClassLoader) match { + JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 6e6ed8d819..7158abc088 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils /** @@ -197,17 +198,10 @@ private[util] object BatchedWriteAheadLog { */ case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) - /** Copies the byte array of a ByteBuffer. */ - private def getByteArray(buffer: ByteBuffer): Array[Byte] = { - val byteArray = new Array[Byte](buffer.remaining()) - buffer.get(byteArray) - byteArray - } - /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ def aggregate(records: Seq[Record]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( - records.map(record => getByteArray(record.data)).toArray)) + records.map(record => JavaUtils.bufferToArray(record.data)).toArray)) } /** @@ -216,10 +210,13 @@ private[util] object BatchedWriteAheadLog { * method therefore needs to be backwards compatible. */ def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + val prevPosition = buffer.position() try { - Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap) } catch { case _: ClassCastException => // users may restart a stream with batching enabled + // Restore `position` so that the user can read `buffer` later + buffer.position(prevPosition) Array(buffer) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index e146bec32a..1185f30265 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -24,6 +24,8 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.spark.util.Utils + /** * A writer for writing byte-buffers to a write ahead log file. */ @@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: val lengthToWrite = data.remaining() val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } + Utils.writeByteBuffer(data, stream: OutputStream) flush() nextOffset = stream.getPos() segment diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 09b5f8ed03..f02fa87f61 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; @@ -27,6 +26,7 @@ import java.util.List; import com.google.common.base.Function; import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; import org.apache.spark.streaming.util.WriteAheadLogUtils; @@ -112,20 +112,19 @@ public class JavaWriteAheadLogSuite extends WriteAheadLog { WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = - wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); + WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); + Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); - wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); - wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); + wal.write(JavaUtils.stringToBytes("data2"), 1235); + wal.write(JavaUtils.stringToBytes("data3"), 1236); + wal.write(JavaUtils.stringToBytes("data4"), 1237); wal.clean(1236, false); Iterator<ByteBuffer> dataIterator = wal.readAll(); List<String> readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); + readData.add(JavaUtils.bytesToString(dataIterator.next())); } Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } |