Diffstat (limited to 'core')
 core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala                 |  2 +-
 core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala |  7 ++++---
 core/src/main/scala/org/apache/spark/storage/BlockManager.scala         | 15 ++++++++-------
 core/src/test/scala/org/apache/spark/DistributedSuite.scala             |  6 ++++--
 4 files changed, 17 insertions(+), 13 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 63d1d1767a..d47b75544f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
     assertValid()
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDPartition].blockId
-    blockManager.get(blockId) match {
+    blockManager.get[T](blockId) match {
       case Some(block) => block.data.asInstanceOf[Iterator[T]]
       case None =>
         throw new Exception("Could not compute split, block " + blockId + " not found")
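
Note: the one-token change above matters because `get` is now generic. If a call site does not name the element type, Scala infers `Nothing` for `T`, and the ClassTag that reaches a ClassTag-aware serializer carries no useful type information. A standalone sketch of that inference behaviour (plain Scala, not Spark code; the `get` stub below is illustrative only):

import scala.reflect.ClassTag

object ClassTagAtCallSite {
  // Illustrative stand-in for BlockManager.get[T]: it only reports which ClassTag it received.
  def get[T: ClassTag](blockId: String): String =
    s"$blockId would be deserialized with ClassTag[${implicitly[ClassTag[T]]}]"

  def main(args: Array[String]): Unit = {
    println(get("rdd_0_0"))      // ClassTag[Nothing] -- the element type is lost
    println(get[Int]("rdd_0_0")) // ClassTag[Int]     -- what BlockRDD.compute now does
  }
}
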
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7b1ec6fcbb..2156d576f1 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -180,11 +180,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * Deserializes an InputStream into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserializeStream[T: ClassTag](
+  def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream): Iterator[T] = {
+      inputStream: InputStream)
+      (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    getSerializer(implicitly[ClassTag[T]])
+    getSerializer(classTag)
       .newInstance()
       .deserializeStream(wrapStream(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
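
Note: moving the tag from a `[T: ClassTag]` context bound to an explicit second parameter list forces every caller to state which tag it is deserializing with, instead of letting the compiler quietly materialize one from whatever `T` happened to be inferred. A minimal standalone sketch of the same signature shape (plain Java serialization stands in for Spark's SerializerInstance; all names here are illustrative):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, ObjectInputStream, ObjectOutputStream}
import scala.reflect.ClassTag

object ExplicitTagDeserialize {
  // Same shape as the new dataDeserializeStream: the ClassTag is an explicit argument,
  // so a call site cannot compile without choosing one.
  def deserializeStream[T](in: InputStream)(classTag: ClassTag[T]): Iterator[T] = {
    // A ClassTag-aware serializer would consult classTag.runtimeClass here;
    // java.io.ObjectInputStream ignores it, which keeps the sketch small.
    val ois = new ObjectInputStream(in)
    Iterator.single(ois.readObject().asInstanceOf[T])
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(Integer.valueOf(42))
    oos.close()
    val values = deserializeStream(new ByteArrayInputStream(buf.toByteArray))(ClassTag.Int)
    println(values.toList) // List(42)
  }
}
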
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c72f28e00c..0614646771 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -520,10 +520,11 @@ private[spark] class BlockManager(
    *
    * This does not acquire a lock on this block in this JVM.
    */
-  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+  private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+    val ct = implicitly[ClassTag[T]]
     getRemoteBytes(blockId).map { data =>
       val values =
-        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
       new BlockResult(values, DataReadMethod.Network, data.size)
     }
   }
@@ -602,13 +603,13 @@ private[spark] class BlockManager(
    * any locks if the block was fetched from a remote block manager. The read lock will
    * automatically be freed once the result's `data` iterator is fully consumed.
    */
-  def get(blockId: BlockId): Option[BlockResult] = {
+  def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
     val local = getLocalValues(blockId)
     if (local.isDefined) {
       logInfo(s"Found block $blockId locally")
       return local
     }
-    val remote = getRemoteValues(blockId)
+    val remote = getRemoteValues[T](blockId)
     if (remote.isDefined) {
       logInfo(s"Found block $blockId remotely")
       return remote
@@ -660,7 +661,7 @@ private[spark] class BlockManager(
       makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
     // Attempt to read the block from local or remote storage. If it's present, then we don't need
     // to go through the local-get-or-put path.
-    get(blockId) match {
+    get[T](blockId)(classTag) match {
       case Some(block) =>
         return Left(block)
       case _ =>
@@ -1204,8 +1205,8 @@ private[spark] class BlockManager(
   /**
    * Read a block consisting of a single object.
    */
-  def getSingle(blockId: BlockId): Option[Any] = {
-    get(blockId).map(_.data.next())
+  def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+    get[T](blockId).map(_.data.next().asInstanceOf[T])
   }
 
   /**
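
Note: taken together, the BlockManager changes thread a single ClassTag from the typed public entry points (`get[T]`, `getSingle[T]`, and `getOrElseUpdate`, which already takes a `classTag` parameter and now passes it through explicitly) down to `getRemoteValues[T]`, where the remotely fetched bytes are actually deserialized; the local-read path is untouched by this diff. A standalone sketch of that threading pattern (class and method names are illustrative, plain Java serialization in place of SerializerManager):

import java.io.{ByteArrayInputStream, ObjectInputStream}
import scala.reflect.ClassTag

// Illustrative only: mirrors the get[T] -> getRemoteValues[T] -> deserialize(...)(ct) chain.
class TinyBlockStore(remoteBytes: Map[String, Array[Byte]]) {

  // Public typed entry point, like BlockManager.get[T: ClassTag].
  def get[T: ClassTag](blockId: String): Option[Iterator[T]] =
    getRemoteValues[T](blockId)

  // Private helper, like getRemoteValues[T: ClassTag]: capture the tag once
  // and hand it on explicitly to the deserializer.
  private def getRemoteValues[T: ClassTag](blockId: String): Option[Iterator[T]] = {
    val ct = implicitly[ClassTag[T]]
    remoteBytes.get(blockId).map(bytes => deserialize(bytes)(ct))
  }

  // Stand-in for SerializerManager.dataDeserializeStream: the tag arrives as a plain argument.
  private def deserialize[T](bytes: Array[Byte])(ct: ClassTag[T]): Iterator[T] = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    Iterator.single(in.readObject().asInstanceOf[T])
  }
}

A caller writes `store.get[Int]("rdd_0_0")`, and the innermost deserializer sees that same `ClassTag[Int]`; the patched `getSingle` is the same idea with a `.next()` on the result.
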
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 4ee0e00fde..4e36adc8ba 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -170,10 +170,12 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     blockManager.master.getLocations(blockId).foreach { cmId =>
       val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
         blockId.toString)
-      val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
-        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
+      val deserialized = serializerManager.dataDeserializeStream(blockId,
+        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList
       assert(deserialized === (1 to 100).toList)
     }
+    // This will exercise the getRemoteBytes / getRemoteValues code paths:
+    assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet)
   }
 
   Seq(
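
Note: `data` here is the cached RDD, so `data.elementClassTag` supplies the tag the RDD already carries for its element type and the test no longer hard-codes `[Int]` when deserializing the fetched bytes; the added assertion then reads every cached block back through the typed `get[Int]` path, which is what drives `getRemoteValues` when a block lives on the other executor. A standalone sketch of the "ask the typed container for its own ClassTag" pattern (names are illustrative, not Spark APIs):

import scala.reflect.ClassTag

// Illustrative only: a typed container that exposes its own element ClassTag,
// the way the test reads data.elementClassTag off the RDD above.
class TypedBlock[T](val values: Seq[T])(implicit val elementClassTag: ClassTag[T])

object ElementTagExample {
  // Consumer that wants the tag explicitly, like the new dataDeserializeStream signature.
  def describe[T](block: TypedBlock[T])(tag: ClassTag[T]): String =
    s"${block.values.size} elements of ${tag.runtimeClass.getSimpleName}"

  def main(args: Array[String]): Unit = {
    val data = new TypedBlock((1 to 100).toVector)
    println(describe(data)(data.elementClassTag)) // 100 elements of int
  }
}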