-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 5
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala | 3
6 files changed, 22 insertions(+), 16 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 63d1d1767a..d47b75544f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
     assertValid()
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDPartition].blockId
-    blockManager.get(blockId) match {
+    blockManager.get[T](blockId) match {
       case Some(block) => block.data.asInstanceOf[Iterator[T]]
       case None =>
         throw new Exception("Could not compute split, block " + blockId + " not found")
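
For context, a minimal sketch of what a caller of the now-typed BlockManager.get[T] looks like; the object and helper names below are invented for illustration, and since BlockManager is internal to Spark a real caller would sit inside Spark's own packages:

import scala.reflect.ClassTag

import org.apache.spark.SparkEnv
import org.apache.spark.storage.BlockId

object BlockReadSketch {
  // Reads a block through the typed get[T]; the element ClassTag now reaches
  // the remote-read deserializer instead of being lost.
  def readBlock[T: ClassTag](blockId: BlockId): Iterator[T] = {
    SparkEnv.get.blockManager.get[T](blockId) match {
      case Some(result) => result.data.asInstanceOf[Iterator[T]]
      case None => throw new Exception(s"Could not read block $blockId")
    }
  }
}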
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7b1ec6fcbb..2156d576f1 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -180,11 +180,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * Deserializes an InputStream into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserializeStream[T: ClassTag](
+  def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream): Iterator[T] = {
+      inputStream: InputStream)
+      (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    getSerializer(implicitly[ClassTag[T]])
+    getSerializer(classTag)
       .newInstance()
       .deserializeStream(wrapStream(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
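
Since the ClassTag is now an explicit second parameter list rather than a context bound, callers can forward a tag they only hold as a runtime value. A small sketch under that assumption; the object and method names are made up, and it is placed in Spark's own package because SerializerManager is private[spark]:

package org.apache.spark.serializer

import java.io.InputStream

import scala.reflect.ClassTag

import org.apache.spark.storage.BlockId

object DataDeserializeSketch {
  // Forward a ClassTag held as a plain value into the explicit curried
  // parameter list; no implicit resolution is needed at this definition site.
  def deserialize[T](
      manager: SerializerManager,
      blockId: BlockId,
      in: InputStream,
      ct: ClassTag[T]): Iterator[T] = {
    manager.dataDeserializeStream(blockId, in)(ct)
  }
}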
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c72f28e00c..0614646771 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -520,10 +520,11 @@ private[spark] class BlockManager(
    *
    * This does not acquire a lock on this block in this JVM.
    */
-  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+  private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+    val ct = implicitly[ClassTag[T]]
     getRemoteBytes(blockId).map { data =>
       val values =
-        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
       new BlockResult(values, DataReadMethod.Network, data.size)
     }
   }
@@ -602,13 +603,13 @@ private[spark] class BlockManager(
    * any locks if the block was fetched from a remote block manager. The read lock will
    * automatically be freed once the result's `data` iterator is fully consumed.
    */
-  def get(blockId: BlockId): Option[BlockResult] = {
+  def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
     val local = getLocalValues(blockId)
     if (local.isDefined) {
       logInfo(s"Found block $blockId locally")
       return local
     }
-    val remote = getRemoteValues(blockId)
+    val remote = getRemoteValues[T](blockId)
     if (remote.isDefined) {
       logInfo(s"Found block $blockId remotely")
       return remote
@@ -660,7 +661,7 @@ private[spark] class BlockManager(
       makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
     // Attempt to read the block from local or remote storage. If it's present, then we don't need
     // to go through the local-get-or-put path.
-    get(blockId) match {
+    get[T](blockId)(classTag) match {
       case Some(block) =>
         return Left(block)
       case _ =>
@@ -1204,8 +1205,8 @@ private[spark] class BlockManager(
   /**
    * Read a block consisting of a single object.
    */
-  def getSingle(blockId: BlockId): Option[Any] = {
-    get(blockId).map(_.data.next())
+  def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+    get[T](blockId).map(_.data.next().asInstanceOf[T])
   }
 
   /**
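
A short sketch of the typed single-object read; the block id is a placeholder and, as above, real callers live inside Spark since BlockManager is spark-internal:

import org.apache.spark.SparkEnv
import org.apache.spark.storage.TestBlockId

object GetSingleSketch {
  def firstString(): Option[String] = {
    // getSingle[T] now returns Option[T] instead of Option[Any]; the cast to
    // the element type happens inside getSingle itself.
    SparkEnv.get.blockManager.getSingle[String](TestBlockId("example-block"))
  }
}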
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 4ee0e00fde..4e36adc8ba 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -170,10 +170,12 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     blockManager.master.getLocations(blockId).foreach { cmId =>
       val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
         blockId.toString)
-      val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
-        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
+      val deserialized = serializerManager.dataDeserializeStream(blockId,
+        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList
       assert(deserialized === (1 to 100).toList)
     }
+    // This will exercise the getRemoteBytes / getRemoteValues code paths:
+    assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet)
   }
 
   Seq(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index 53fccd8d5e..0b2ec29813 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -120,7 +120,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
     val blockId = partition.blockId
 
     def getBlockFromBlockManager(): Option[Iterator[T]] = {
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
+      blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]])
     }
 
     def getBlockFromWriteAheadLog(): Iterator[T] = {
@@ -163,7 +163,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
         dataRead.rewind()
       }
       serializerManager
-        .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream())
+        .dataDeserializeStream(
+          blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
         .asInstanceOf[Iterator[T]]
     }
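
The elementClassTag passed above comes from the RDD's [T: ClassTag] context bound. A generic, self-contained sketch of that pattern (names invented; this is not the WAL code itself):

import scala.reflect.{classTag, ClassTag}

// A class with a ClassTag context bound can hand the tag on, as a plain
// value, to any API that takes an explicit ClassTag parameter.
class TagCarrier[T: ClassTag] {
  def tag: ClassTag[T] = classTag[T]
}

object TagCarrierExample {
  def main(args: Array[String]): Unit = {
    println(new TagCarrier[Int].tag.runtimeClass.getName)  // prints "int"
  }
}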
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index feb5c30c6a..7e665454a5 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfter, Matchers}
@@ -163,7 +164,7 @@ class ReceivedBlockHandlerSuite
         val bytes = reader.read(fileSegment)
         reader.close()
         serializerManager.dataDeserializeStream(
-          generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList
+          generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
       }
       loggedData shouldEqual data
     }
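
ClassTag.Any is the escape hatch when no element type is available, as in this test. A standalone sketch of what that fallback means; the decode helper is a toy with the same shape as dataDeserializeStream, not a Spark API:

import scala.reflect.ClassTag

object ClassTagFallbackSketch {
  // Toy API in the same shape as the new dataDeserializeStream: the element
  // ClassTag is an explicit curried argument.
  def decode[T](values: Seq[Any])(ct: ClassTag[T]): Iterator[T] =
    values.iterator.map(_.asInstanceOf[T])

  def main(args: Array[String]): Unit = {
    val mixed = Seq[Any](1, "two", 3.0)
    // With no element type to hand, ClassTag.Any keeps the call well-typed and
    // simply yields the values as Any, matching the old untyped behaviour.
    val out: Iterator[Any] = decode(mixed)(ClassTag.Any)
    println(out.toList)  // List(1, two, 3.0)
  }
}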