author    Shixiong Zhu <shixiong@databricks.com>    2015-12-04 17:02:04 -0800
committer Marcelo Vanzin <vanzin@cloudera.com>    2015-12-04 17:02:04 -0800
commit    3af53e61fd604fe8000e1fdf656d60b79c842d1c (patch)
tree      b12b7a4cdb05361813fc81f93bbed924ba961822 /streaming
parent    f30373f5ee60f9892c28771e34b208e4f1f675a6 (diff)
[SPARK-12084][CORE] Fix code that uses ByteBuffer.array incorrectly

`ByteBuffer` doesn't guarantee that all of the contents of `ByteBuffer.array` are valid; for example, a ByteBuffer returned by `ByteBuffer.slice` still shares the full backing array of its parent. We should not use the whole backing array of a `ByteBuffer` unless we know that is correct. This patch fixes all places that use `ByteBuffer.array` incorrectly.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10083 from zsxwing/bytebuffer-array.
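For illustration only (not part of the patch): a minimal Scala sketch of the pitfall described above, using a made-up backing array. `array()` exposes the entire backing array of a sliced buffer, while copying `remaining()` bytes from the current position yields only the valid contents.

    import java.nio.ByteBuffer
    import java.nio.charset.StandardCharsets.UTF_8

    // A buffer that is only a window into a larger backing array.
    val backing = "xxxxHELLOxxxx".getBytes(UTF_8)
    val buf = ByteBuffer.wrap(backing, 4, 5).slice()

    // Incorrect: array() returns the whole backing array, padding included.
    val wrong = new String(buf.array(), UTF_8)   // "xxxxHELLOxxxx"

    // Correct: copy exactly the bytes between position and limit.
    val bytes = new Array[Byte](buf.remaining())
    buf.duplicate().get(bytes)                   // duplicate() keeps buf's position intact
    val right = new String(bytes, UTF_8)         // "HELLO"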
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala      |  5
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala           | 15
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala   | 14
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java                | 15
4 files changed, 19 insertions, 30 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index 500dc70c98..4dab64d696 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils}
import org.apache.spark.util.{Clock, Utils}
@@ -210,9 +211,9 @@ private[streaming] class ReceivedBlockTracker(
writeAheadLogOption.foreach { writeAheadLog =>
logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}")
writeAheadLog.readAll().asScala.foreach { byteBuffer =>
- logTrace("Recovering record " + byteBuffer)
+ logInfo("Recovering record " + byteBuffer)
Utils.deserialize[ReceivedBlockTrackerLogEvent](
- byteBuffer.array, Thread.currentThread().getContextClassLoader) match {
+ JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match {
case BlockAdditionEvent(receivedBlockInfo) =>
insertAddedBlock(receivedBlockInfo)
case BatchAllocationEvent(time, allocatedBlocks) =>
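The helper introduced above, JavaUtils.bufferToArray, is not shown in this diff. As a hedged sketch only (assumed behavior, with a placeholder name, not the actual org.apache.spark.network.util.JavaUtils implementation), a copy that respects the buffer's position and limit looks roughly like:

    import java.nio.ByteBuffer

    // Sketch only: copy the readable window of a ByteBuffer into a fresh array,
    // instead of trusting the whole backing array.
    def bufferToArraySketch(buffer: ByteBuffer): Array[Byte] = {
      val bytes = new Array[Byte](buffer.remaining())
      buffer.get(bytes)   // note: like ByteBuffer.get, this advances the buffer's position
      bytes
    }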
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
index 6e6ed8d819..7158abc088 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -28,6 +28,7 @@ import scala.concurrent.duration._
import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
/**
@@ -197,17 +198,10 @@ private[util] object BatchedWriteAheadLog {
*/
case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle])
- /** Copies the byte array of a ByteBuffer. */
- private def getByteArray(buffer: ByteBuffer): Array[Byte] = {
- val byteArray = new Array[Byte](buffer.remaining())
- buffer.get(byteArray)
- byteArray
- }
-
/** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */
def aggregate(records: Seq[Record]): ByteBuffer = {
ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](
- records.map(record => getByteArray(record.data)).toArray))
+ records.map(record => JavaUtils.bufferToArray(record.data)).toArray))
}
/**
@@ -216,10 +210,13 @@ private[util] object BatchedWriteAheadLog {
* method therefore needs to be backwards compatible.
*/
def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = {
+ val prevPosition = buffer.position()
try {
- Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap)
+ Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap)
} catch {
case _: ClassCastException => // users may restart a stream with batching enabled
+ // Restore `position` so that the user can read `buffer` later
+ buffer.position(prevPosition)
Array(buffer)
}
}
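A small illustrative sketch (not from the patch) of why deaggregate saves and restores the position: copying the remaining bytes consumes the buffer, so without the reset a legacy, non-batched record handed back in the catch branch would appear empty.

    import java.nio.ByteBuffer

    val buffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
    val prevPosition = buffer.position()

    val copy = new Array[Byte](buffer.remaining())
    buffer.get(copy)                 // advances position to the limit; remaining() is now 0

    buffer.position(prevPosition)    // remaining() is 4 again, so the buffer can be re-read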
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
index e146bec32a..1185f30265 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
@@ -24,6 +24,8 @@ import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataOutputStream
+import org.apache.spark.util.Utils
+
/**
* A writer for writing byte-buffers to a write ahead log file.
*/
@@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf:
val lengthToWrite = data.remaining()
val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite)
stream.writeInt(lengthToWrite)
- if (data.hasArray) {
- stream.write(data.array())
- } else {
- // If the buffer is not backed by an array, we transfer using temp array
- // Note that despite the extra array copy, this should be faster than byte-by-byte copy
- while (data.hasRemaining) {
- val array = new Array[Byte](data.remaining)
- data.get(array)
- stream.write(array)
- }
- }
+ Utils.writeByteBuffer(data, stream: OutputStream)
flush()
nextOffset = stream.getPos()
segment
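Utils.writeByteBuffer itself is outside this diff. A hedged sketch of a helper with the behavior the replaced block needs (hypothetical name writeBufferTo; not the actual Utils implementation):

    import java.io.OutputStream
    import java.nio.ByteBuffer

    // Sketch only: write the readable window of a ByteBuffer to an OutputStream.
    def writeBufferTo(data: ByteBuffer, stream: OutputStream): Unit = {
      if (data.hasArray) {
        // Array-backed: write only the window between position and limit,
        // honouring arrayOffset rather than dumping the whole backing array.
        stream.write(data.array(), data.arrayOffset() + data.position(), data.remaining())
      } else {
        // Direct (or otherwise non-array) buffer: copy through a temporary array.
        val tmp = new Array[Byte](data.remaining())
        data.duplicate().get(tmp)    // duplicate() leaves the caller's position untouched
        stream.write(tmp)
      }
    }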
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
index 09b5f8ed03..f02fa87f61 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.streaming;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.nio.ByteBuffer;
import java.util.Arrays;
@@ -27,6 +26,7 @@ import java.util.List;
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import org.apache.spark.SparkConf;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.streaming.util.WriteAheadLog;
import org.apache.spark.streaming.util.WriteAheadLogRecordHandle;
import org.apache.spark.streaming.util.WriteAheadLogUtils;
@@ -112,20 +112,19 @@ public class JavaWriteAheadLogSuite extends WriteAheadLog {
WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null);
String data1 = "data1";
- WriteAheadLogRecordHandle handle =
- wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234);
+ WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234);
Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle);
- Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1);
+ Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1);
- wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235);
- wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236);
- wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237);
+ wal.write(JavaUtils.stringToBytes("data2"), 1235);
+ wal.write(JavaUtils.stringToBytes("data3"), 1236);
+ wal.write(JavaUtils.stringToBytes("data4"), 1237);
wal.clean(1236, false);
Iterator<ByteBuffer> dataIterator = wal.readAll();
List<String> readData = new ArrayList<>();
while (dataIterator.hasNext()) {
- readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8));
+ readData.add(JavaUtils.bytesToString(dataIterator.next()));
}
Assert.assertEquals(readData, Arrays.asList("data3", "data4"));
}
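For reference, a hedged Scala sketch of the string/ByteBuffer round-trip the updated test relies on. The real helpers are JavaUtils.stringToBytes and JavaUtils.bytesToString; the names below are placeholders, not their actual implementations.

    import java.nio.ByteBuffer
    import java.nio.charset.StandardCharsets.UTF_8

    // Placeholder equivalents of the JavaUtils helpers used in the test.
    def stringToBuffer(s: String): ByteBuffer = ByteBuffer.wrap(s.getBytes(UTF_8))
    def bufferToString(b: ByteBuffer): String = UTF_8.decode(b.duplicate()).toString

    assert(bufferToString(stringToBuffer("data1")) == "data1")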