diff options
author | Shixiong Zhu <shixiong@databricks.com> | 2015-12-04 17:02:04 -0800 |
---|---|---|
committer | Marcelo Vanzin <vanzin@cloudera.com> | 2015-12-04 17:02:04 -0800 |
commit | 3af53e61fd604fe8000e1fdf656d60b79c842d1c (patch) | |
tree | b12b7a4cdb05361813fc81f93bbed924ba961822 /core/src/main | |
parent | f30373f5ee60f9892c28771e34b208e4f1f675a6 (diff) | |
download | spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.tar.gz spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.tar.bz2 spark-3af53e61fd604fe8000e1fdf656d60b79c842d1c.zip |
[SPARK-12084][CORE] Fix code that uses ByteBuffer.array incorrectly
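`ByteBuffer` doesn't guarantee that all contents of `ByteBuffer.array` are valid. For example, a `ByteBuffer` returned by `ByteBuffer.slice` shares its parent's backing array, so only the `remaining()` bytes starting at `arrayOffset() + position()` belong to it. We should not use the whole backing array of a `ByteBuffer` unless we know that is correct.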
This patch fixes all places that use `ByteBuffer.array` incorrectly.
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #10083 from zsxwing/bytebuffer-array.
Diffstat (limited to 'core/src/main')
7 files changed, 30 insertions, 18 deletions
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 82c16e855b..40604a4da1 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded // using our binary protocol. - val levelBytes = serializer.newInstance().serialize(level).array() + val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level)) // Convert or copy nio buffer into array in order to serialize it. 
- val nioBuffer = blockData.nioByteBuffer() - val array = if (nioBuffer.hasArray) { - nioBuffer.array() - } else { - val data = new Array[Byte](nioBuffer.remaining()) - nioBuffer.get(data) - data - } + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e01a9609b9..5582720bbc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout @@ -997,9 +998,10 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). 
val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2fcd5aa57d..5fe5ae8c45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -191,8 +191,8 @@ private[spark] object Task { // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) + val taskBytes = serializer.serialize(task) + Utils.writeByteBuffer(taskBytes, out) ByteBuffer.wrap(out.toByteArray) } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 62f8aae7f2..8d6af9cae8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) * seen values so to limit the number of times that decompression has to be done. 
*/ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { - val bis = new ByteArrayInputStream(schemaBytes.array()) + val bis = new ByteArrayInputStream( + schemaBytes.array(), + schemaBytes.arrayOffset() + schemaBytes.position(), + schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) new Schema.Parser().parse(new String(bytes, "UTF-8")) }) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7b77f78ce6..62d445f3d7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -309,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -321,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 22878783fc..d14fe46135 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log 
val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) try { - os.write(bytes.array()) + Utils.writeByteBuffer(bytes, os) } catch { case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index af632349c9..9dbe66e7ee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -178,7 +178,20 @@ private[spark] object Utils extends Logging { /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] + */ + def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { |