aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala')
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala31
1 files changed, 13 insertions, 18 deletions
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
index f4e56632e4..8c96b0e71d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
@@ -19,18 +19,19 @@
// when they are outside of org.apache.spark.
package other.supplier
+import java.nio.ByteBuffer
+
import scala.collection.mutable
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master._
+import org.apache.spark.serializer.Serializer
class CustomRecoveryModeFactory(
conf: SparkConf,
- serialization: Serialization
-) extends StandaloneRecoveryModeFactory(conf, serialization) {
+ serializer: Serializer
+) extends StandaloneRecoveryModeFactory(conf, serializer) {
CustomRecoveryModeFactory.instantiationAttempts += 1
@@ -40,7 +41,7 @@ class CustomRecoveryModeFactory(
*
*/
override def createPersistenceEngine(): PersistenceEngine =
- new CustomPersistenceEngine(serialization)
+ new CustomPersistenceEngine(serializer)
/**
* Create an instance of LeaderAgent that decides who gets elected as master.
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory {
@volatile var instantiationAttempts = 0
}
-class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
+class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine {
val data = mutable.HashMap[String, Array[Byte]]()
CustomPersistenceEngine.lastInstance = Some(this)
@@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def persist(name: String, obj: Object): Unit = {
CustomPersistenceEngine.persistAttempts += 1
- serialization.serialize(obj) match {
- case util.Success(bytes) => data += name -> bytes
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
+ val serialized = serializer.newInstance().serialize(obj)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ data += name -> bytes
}
/**
@@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def read[T: ClassTag](prefix: String): Seq[T] = {
CustomPersistenceEngine.readAttempts += 1
- val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
val results = for ((name, bytes) <- data; if name.startsWith(prefix))
- yield serialization.deserialize(bytes, clazz)
-
- results.find(_.isFailure).foreach {
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
-
- results.flatMap(_.toOption).toSeq
+ yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+ results.toSeq
}
}