author     Shixiong Zhu <shixiong@databricks.com>    2015-12-04 17:02:04 -0800
committer  Marcelo Vanzin <vanzin@cloudera.com>      2015-12-04 17:02:04 -0800
commit     3af53e61fd604fe8000e1fdf656d60b79c842d1c (patch)
tree       b12b7a4cdb05361813fc81f93bbed924ba961822
parent     f30373f5ee60f9892c28771e34b208e4f1f675a6 (diff)
download   spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.tar.gz
           spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.tar.bz2
           spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.zip
[SPARK-12084][CORE] Fix code that uses ByteBuffer.array incorrectly
`ByteBuffer` doesn't guarantee that all of the contents of `ByteBuffer.array` are valid, e.g. a ByteBuffer returned by `ByteBuffer.slice`. We should not use the whole content of `ByteBuffer.array` unless we know it is correct. This patch fixes all places that use `ByteBuffer.array` incorrectly.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10083 from zsxwing/bytebuffer-array.
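The pitfall described above can be reproduced in a few lines. This is an illustrative sketch (not part of the patch, e.g. runnable in the Scala REPL) showing why `array()` must not be treated as the buffer's contents, and what the safe copy looks like; `JavaUtils.bufferToArray`, used throughout the patch, performs the equivalent copy.

    import java.nio.ByteBuffer

    val backing = ByteBuffer.wrap("headerPAYLOAD".getBytes("UTF-8"))
    backing.position("header".length)
    val slice = backing.slice()            // shares the backing array; arrayOffset() == 6

    // Wrong: array() exposes the whole backing array, including the "header" bytes.
    new String(slice.array(), "UTF-8")     // "headerPAYLOAD"

    // Right: copy only the remaining() bytes between position() and limit().
    val copy = new Array[Byte](slice.remaining())
    slice.duplicate().get(copy)            // duplicate() leaves the original position untouched
    new String(copy, "UTF-8")              // "PAYLOAD"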
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 15
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 3
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala | 5
-rw-r--r--  external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala | 6
-rw-r--r--  external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala | 4
-rw-r--r--  external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala | 4
-rw-r--r--  extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala | 3
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala | 8
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala | 5
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala | 15
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala | 14
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java | 15
22 files changed, 81 insertions(+), 69 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 82c16e855b..40604a4da1 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
// StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
// using our binary protocol.
- val levelBytes = serializer.newInstance().serialize(level).array()
+ val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))
// Convert or copy nio buffer into array in order to serialize it.
- val nioBuffer = blockData.nioByteBuffer()
- val array = if (nioBuffer.hasArray) {
- nioBuffer.array()
- } else {
- val data = new Array[Byte](nioBuffer.remaining())
- nioBuffer.get(data)
- data
- }
+ val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
new RpcResponseCallback {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e01a9609b9..5582720bbc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcTimeout
@@ -997,9 +998,10 @@ class DAGScheduler(
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] = stage match {
case stage: ShuffleMapStage =>
- closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
+ JavaUtils.bufferToArray(
+ closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
- closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
+ JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
}
taskBinary = sc.broadcast(taskBinaryBytes)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2fcd5aa57d..5fe5ae8c45 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -191,8 +191,8 @@ private[spark] object Task {
// Write the task itself and finish
dataOut.flush()
- val taskBytes = serializer.serialize(task).array()
- out.write(taskBytes)
+ val taskBytes = serializer.serialize(task)
+ Utils.writeByteBuffer(taskBytes, out)
ByteBuffer.wrap(out.toByteArray)
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index 62f8aae7f2..8d6af9cae8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
* seen values so to limit the number of times that decompression has to be done.
*/
def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
- val bis = new ByteArrayInputStream(schemaBytes.array())
+ val bis = new ByteArrayInputStream(
+ schemaBytes.array(),
+ schemaBytes.arrayOffset() + schemaBytes.position(),
+ schemaBytes.remaining())
val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
new Schema.Parser().parse(new String(bytes, "UTF-8"))
})
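The replacement above, and the KryoSerializer and Parquet reader changes below, all apply the same indexing rule: a heap buffer's valid bytes live in `array()` at the window `[arrayOffset() + position(), arrayOffset() + position() + remaining())`. A minimal illustrative helper (not part of the patch; the name and the `require` guard are assumptions) that captures the rule:

    import java.io.ByteArrayInputStream
    import java.nio.ByteBuffer

    // Zero-copy view over only the valid window of a heap buffer.
    def validRegionStream(bb: ByteBuffer): ByteArrayInputStream = {
      require(bb.hasArray, "array() is only meaningful for heap (array-backed) buffers")
      new ByteArrayInputStream(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
    }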
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 7b77f78ce6..62d445f3d7 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -309,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val kryo = borrowKryo()
try {
- input.setBuffer(bytes.array)
+ input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
releaseKryo(kryo)
@@ -321,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
val oldClassLoader = kryo.getClassLoader
try {
kryo.setClassLoader(loader)
- input.setBuffer(bytes.array)
+ input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
kryo.setClassLoader(oldClassLoader)
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index 22878783fc..d14fe46135 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
val file = getFile(blockId)
val os = file.getOutStream(WriteType.TRY_CACHE)
try {
- os.write(bytes.array())
+ Utils.writeByteBuffer(bytes, os)
} catch {
case NonFatal(e) =>
logWarning(s"Failed to put bytes of block $blockId into Tachyon", e)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index af632349c9..9dbe66e7ee 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -178,7 +178,20 @@ private[spark] object Utils extends Logging {
/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
- def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = {
+ def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
+ if (bb.hasArray) {
+ out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ } else {
+ val bbval = new Array[Byte](bb.remaining())
+ bb.get(bbval)
+ out.write(bbval)
+ }
+ }
+
+ /**
+ * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
+ */
+ def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
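The `OutputStream` overload's body is truncated by the diff context, but both overloads follow the same shape: write directly from `arrayOffset() + position()` for `remaining()` bytes when a backing array is accessible, otherwise copy into a temporary array. A hedged usage sketch (illustrative only; `Utils` is `private[spark]`, so this compiles only inside Spark itself):

    import java.io.ByteArrayOutputStream
    import java.nio.ByteBuffer
    import org.apache.spark.util.Utils

    // A direct buffer has no accessible backing array, so the helper falls back to a temp copy.
    val direct = ByteBuffer.allocateDirect(4)
    direct.put(Array[Byte](1, 2, 3, 4))
    direct.flip()
    val out = new ByteArrayOutputStream()
    Utils.writeByteBuffer(direct, out)   // writes exactly direct.remaining() bytes
    assert(out.toByteArray.sameElements(Array[Byte](1, 2, 3, 4)))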
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index a5c583f9f2..8724a34988 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -41,6 +41,7 @@ import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
@@ -430,7 +431,7 @@ public abstract class AbstractBytesToBytesMapSuite {
}
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
- final byte[] key = entry.getKey().array();
+ final byte[] key = JavaUtils.bufferToArray(entry.getKey());
final byte[] value = entry.getValue();
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
@@ -480,7 +481,7 @@ public abstract class AbstractBytesToBytesMapSuite {
}
}
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
- final byte[] key = entry.getKey().array();
+ final byte[] key = JavaUtils.bufferToArray(entry.getKey());
final byte[] value = entry.getValue();
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 450ab7b9fe..d83d0aee42 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -23,6 +23,7 @@ import org.mockito.Matchers.any
import org.scalatest.BeforeAndAfter
import org.apache.spark._
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import org.apache.spark.metrics.source.JvmSource
@@ -57,7 +58,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
- val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
+ val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
intercept[RuntimeException] {
diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
index 805184e740..cf12c98b4a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
@@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable {
def unpackBytes(obj: Any): Array[Byte] = {
val bytes: Array[Byte] = obj match {
- case buf: java.nio.ByteBuffer => buf.array()
+ case buf: java.nio.ByteBuffer =>
+ val arr = new Array[Byte](buf.remaining())
+ buf.get(arr)
+ arr
case arr: Array[Byte] => arr
case other => throw new SparkException(
s"Unknown BYTES type ${other.getClass.getName}")
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index c8780aa83b..2b9116eb3c 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -93,9 +93,9 @@ class SparkFlumeEvent() extends Externalizable {
/* Serialize to bytes. */
def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- val body = event.getBody.array()
- out.writeInt(body.length)
- out.write(body)
+ val body = event.getBody
+ out.writeInt(body.remaining())
+ Utils.writeByteBuffer(body, out)
val numHeaders = event.getHeaders.size()
out.writeInt(numHeaders)
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 5fd2711f5f..bb951a6ef1 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -24,11 +24,11 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps
-import com.google.common.base.Charsets.UTF_8
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
@@ -119,7 +119,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log
val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map {
case (key, value) => (key.toString, value.toString)
}).map(_.asJava)
- val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8))
+ val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody))
utils.assertOutput(headers.asJava, bodies.asJava)
}
} finally {
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index f315e0a7ca..b29e591c07 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -22,7 +22,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps
-import com.google.common.base.Charsets
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
@@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
@@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w
event =>
event.getHeaders.get("test") should be("header")
}
- val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8))
+ val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody))
output should be (input)
}
} finally {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 78cec021b7..6fe24fe811 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
@@ -196,7 +197,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
testIfEnabled("custom message handling") {
val awsCredentials = KinesisTestUtils.getAWSCredentials()
- def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5
+ def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5
val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index dade488ca2..0cc4566c9c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -332,12 +332,13 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
for (int n = 0; n < num; ++n) {
if (columnReaders[col].next()) {
ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer();
- int len = bytes.limit() - bytes.position();
+ int len = bytes.remaining();
if (originalTypes[col] == OriginalType.UTF8) {
- UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len);
+ UTF8String str =
+ UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len);
rowWriters[n].write(col, str);
} else {
- rowWriters[n].write(col, bytes.array(), bytes.position(), len);
+ rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len);
}
rows[n].setNotNullAt(col);
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 8317f648cc..45a8e03248 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.MutablePair
@@ -76,7 +77,7 @@ private[sql] object SparkSqlSerializer {
def serialize[T: ClassTag](o: T): Array[Byte] =
acquireRelease { k =>
- k.serialize(o).array()
+ JavaUtils.bufferToArray(k.serialize(o))
}
def deserialize[T: ClassTag](bytes: Array[Byte]): T =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index ce701fb3a7..3c5a8cb2aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -163,7 +164,9 @@ private[sql] case class InMemoryRelation(
.flatMap(_.values))
batchStats += stats
- CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats)
+ CachedBatch(rowCount, columnBuilders.map { builder =>
+ JavaUtils.bufferToArray(builder.build())
+ }, stats)
}
def hasNext: Boolean = rowIterator.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
index 94298fae2d..8851bc23cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
@@ -327,8 +327,8 @@ private[parquet] class CatalystRowConverter(
// are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying
// it.
val buffer = value.toByteBuffer
- val offset = buffer.position()
- val numBytes = buffer.limit() - buffer.position()
+ val offset = buffer.arrayOffset() + buffer.position()
+ val numBytes = buffer.remaining()
updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes))
}
}
@@ -644,8 +644,8 @@ private[parquet] object CatalystRowConverter {
// copying it.
val buffer = binary.toByteBuffer
val bytes = buffer.array()
- val start = buffer.position()
- val end = buffer.limit()
+ val start = buffer.arrayOffset() + buffer.position()
+ val end = buffer.arrayOffset() + buffer.limit()
var unscaled = 0L
var i = start
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index 500dc70c98..4dab64d696 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils}
import org.apache.spark.util.{Clock, Utils}
@@ -210,9 +211,9 @@ private[streaming] class ReceivedBlockTracker(
writeAheadLogOption.foreach { writeAheadLog =>
logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}")
writeAheadLog.readAll().asScala.foreach { byteBuffer =>
- logTrace("Recovering record " + byteBuffer)
+ logInfo("Recovering record " + byteBuffer)
Utils.deserialize[ReceivedBlockTrackerLogEvent](
- byteBuffer.array, Thread.currentThread().getContextClassLoader) match {
+ JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match {
case BlockAdditionEvent(receivedBlockInfo) =>
insertAddedBlock(receivedBlockInfo)
case BatchAllocationEvent(time, allocatedBlocks) =>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
index 6e6ed8d819..7158abc088 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -28,6 +28,7 @@ import scala.concurrent.duration._
import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
/**
@@ -197,17 +198,10 @@ private[util] object BatchedWriteAheadLog {
*/
case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle])
- /** Copies the byte array of a ByteBuffer. */
- private def getByteArray(buffer: ByteBuffer): Array[Byte] = {
- val byteArray = new Array[Byte](buffer.remaining())
- buffer.get(byteArray)
- byteArray
- }
-
/** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */
def aggregate(records: Seq[Record]): ByteBuffer = {
ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](
- records.map(record => getByteArray(record.data)).toArray))
+ records.map(record => JavaUtils.bufferToArray(record.data)).toArray))
}
/**
@@ -216,10 +210,13 @@ private[util] object BatchedWriteAheadLog {
* method therefore needs to be backwards compatible.
*/
def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = {
+ val prevPosition = buffer.position()
try {
- Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap)
+ Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap)
} catch {
case _: ClassCastException => // users may restart a stream with batching enabled
+ // Restore `position` so that the user can read `buffer` later
+ buffer.position(prevPosition)
Array(buffer)
}
}
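The position-restore added to `deaggregate` guards against a consuming read: copying a buffer's bytes advances its `position()`, so a caller handed back the original buffer on the fallback path would otherwise see it as empty. An illustrative sketch of that behavior (not part of the patch; here the saved position happens to be 0):

    import java.nio.ByteBuffer

    val buffer = ByteBuffer.wrap(Array[Byte](1, 2, 3))
    val copied = new Array[Byte](buffer.remaining())
    buffer.get(copied)                  // a consuming read: position advances to limit
    assert(buffer.remaining() == 0)     // a later reader would now see an "empty" buffer
    buffer.position(0)                  // restoring the saved position makes it readable again
    assert(buffer.remaining() == 3)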
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
index e146bec32a..1185f30265 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
@@ -24,6 +24,8 @@ import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataOutputStream
+import org.apache.spark.util.Utils
+
/**
* A writer for writing byte-buffers to a write ahead log file.
*/
@@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf:
val lengthToWrite = data.remaining()
val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite)
stream.writeInt(lengthToWrite)
- if (data.hasArray) {
- stream.write(data.array())
- } else {
- // If the buffer is not backed by an array, we transfer using temp array
- // Note that despite the extra array copy, this should be faster than byte-by-byte copy
- while (data.hasRemaining) {
- val array = new Array[Byte](data.remaining)
- data.get(array)
- stream.write(array)
- }
- }
+ Utils.writeByteBuffer(data, stream: OutputStream)
flush()
nextOffset = stream.getPos()
segment
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
index 09b5f8ed03..f02fa87f61 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.streaming;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.nio.ByteBuffer;
import java.util.Arrays;
@@ -27,6 +26,7 @@ import java.util.List;
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import org.apache.spark.SparkConf;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.streaming.util.WriteAheadLog;
import org.apache.spark.streaming.util.WriteAheadLogRecordHandle;
import org.apache.spark.streaming.util.WriteAheadLogUtils;
@@ -112,20 +112,19 @@ public class JavaWriteAheadLogSuite extends WriteAheadLog {
WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null);
String data1 = "data1";
- WriteAheadLogRecordHandle handle =
- wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234);
+ WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234);
Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle);
- Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1);
+ Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1);
- wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235);
- wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236);
- wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237);
+ wal.write(JavaUtils.stringToBytes("data2"), 1235);
+ wal.write(JavaUtils.stringToBytes("data3"), 1236);
+ wal.write(JavaUtils.stringToBytes("data4"), 1237);
wal.clean(1236, false);
Iterator<ByteBuffer> dataIterator = wal.readAll();
List<String> readData = new ArrayList<>();
while (dataIterator.hasNext()) {
- readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8));
+ readData.add(JavaUtils.bytesToString(dataIterator.next()));
}
Assert.assertEquals(readData, Arrays.asList("data3", "data4"));
}
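For reference, the test changes above lean on `JavaUtils.stringToBytes` and `JavaUtils.bytesToString` (both used in the patch) round-tripping cleanly while respecting the buffer's valid region, unlike `new String(buf.array(), UTF_8)`. A minimal Scala sketch of that round trip (the assertion is illustrative, assuming the obvious semantics of the pair):

    import org.apache.spark.network.util.JavaUtils

    val buf = JavaUtils.stringToBytes("data1")        // UTF-8 encode into a ByteBuffer
    assert(JavaUtils.bytesToString(buf) == "data1")   // decode only the valid region back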