-rw-r--r--  core/pom.xml                                                                          |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala  |  35
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala                       |  18
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala            |   8
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala          |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala   |  16
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                                 |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala                        |  14
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala    |  31
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala                  |   2
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala       | 126
-rw-r--r--  pom.xml                                                                               |   6
12 files changed, 214 insertions(+), 62 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 558cc3fb9f..73f7a75cab 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -373,6 +373,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.apache.curator</groupId>
+ <artifactId>curator-test</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>net.razorvine</groupId>
<artifactId>pyrolite</artifactId>
<version>4.4</version>
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index f459ed5b3a..aa379d4cd6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -21,9 +21,8 @@ import java.io._
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.Logging
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils
@@ -32,11 +31,11 @@ import org.apache.spark.util.Utils
* Files are deleted when applications and workers are removed.
*
* @param dir Directory to store files. Created if non-existent (but not recursively).
- * @param serialization Used to serialize our objects.
+ * @param serializer Used to serialize our objects.
*/
private[master] class FileSystemPersistenceEngine(
val dir: String,
- val serialization: Serialization)
+ val serializer: Serializer)
extends PersistenceEngine with Logging {
new File(dir).mkdir()
@@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine(
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- val out = new FileOutputStream(file)
+ val fileOut = new FileOutputStream(file)
+ var out: SerializationStream = null
Utils.tryWithSafeFinally {
- out.write(serialized)
+ out = serializer.newInstance().serializeStream(fileOut)
+ out.writeObject(value)
} {
- out.close()
+ fileOut.close()
+ if (out != null) {
+ out.close()
+ }
}
}
private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
- val fileData = new Array[Byte](file.length().asInstanceOf[Int])
- val dis = new DataInputStream(new FileInputStream(file))
+ val fileIn = new FileInputStream(file)
+ var in: DeserializationStream = null
try {
- dis.readFully(fileData)
+ in = serializer.newInstance().deserializeStream(fileIn)
+ in.readObject[T]()
} finally {
- dis.close()
+ fileIn.close()
+ if (in != null) {
+ in.close()
+ }
}
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
}
}
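The hunks above swap Akka's Serialization for Spark's own Serializer stream API: writes now go through serializeStream/writeObject and reads through deserializeStream/readObject. A minimal round-trip sketch of that API outside the engine, assuming only a default SparkConf and a temporary file (the object name and string payload are illustrative, not part of the patch):

    import java.io.{File, FileInputStream, FileOutputStream}

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    object SerializerRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val serializer = new JavaSerializer(new SparkConf())
        val file = File.createTempFile("persistence", ".bin")

        // Write: mirror serializeIntoFile — wrap the output stream and write one object.
        val out = serializer.newInstance().serializeStream(new FileOutputStream(file))
        try out.writeObject("app_state") finally out.close()

        // Read: mirror deserializeFromFile — wrap the input stream and read the object back.
        val in = serializer.newInstance().deserializeStream(new FileInputStream(file))
        try assert(in.readObject[String]() == "app_state") finally in.close()
      }
    }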
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 245b047e7d..4615febf17 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.language.postfixOps
import scala.util.Random
-import akka.serialization.Serialization
-import akka.serialization.SerializationExtension
import org.apache.hadoop.fs.Path
-import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
@@ -58,9 +56,6 @@ private[master] class Master(
private val forwardMessageThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
- // TODO Remove it once we don't use akka.serialization.Serialization
- private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
-
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
@@ -161,20 +156,21 @@ private[master] class Master(
masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
+ val serializer = new JavaSerializer(conf)
val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
val zkFactory =
- new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+ new ZooKeeperRecoveryModeFactory(conf, serializer)
(zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
val fsFactory =
- new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+ new FileSystemRecoveryModeFactory(conf, serializer)
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
- val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
- .newInstance(conf, SerializationExtension(actorSystem))
+ val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
+ .newInstance(conf, serializer)
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
@@ -213,7 +209,7 @@ private[master] class Master(
override def receive: PartialFunction[Any, Unit] = {
case ElectedLeader => {
- val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
RecoveryState.ALIVE
} else {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index a03d460509..58a00bceee 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.master
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.RpcEnv
import scala.reflect.ClassTag
@@ -80,8 +81,11 @@ abstract class PersistenceEngine {
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ final def readPersistedData(
+ rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ rpcEnv.deserialize { () =>
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
}
def close() {}
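readPersistedData now takes the RpcEnv because WorkerInfo and ApplicationInfo carry RpcEndpointRefs that can only be rebuilt inside RpcEnv.deserialize. A minimal caller-side sketch under that assumption (the helper object and method name are illustrative; it sits in the master package so the Info types are visible):

    package org.apache.spark.deploy.master

    import org.apache.spark.rpc.RpcEnv

    // Sketch: recovery code calls readPersistedData(rpcEnv) rather than read[...] directly,
    // so the base class can wrap the reads in rpcEnv.deserialize and restore any
    // RpcEndpointRef fields in the stored objects.
    object RecoveryReadSketch {
      def recover(engine: PersistenceEngine, rpcEnv: RpcEnv)
        : (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
        engine.readPersistedData(rpcEnv)
      }
    }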
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
index 351db8fab2..c4c3283fb7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -17,10 +17,9 @@
package org.apache.spark.deploy.master
-import akka.serialization.Serialization
-
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.Serializer
/**
* ::DeveloperApi::
@@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
*
*/
@DeveloperApi
-abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) {
/**
* PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
@@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial
* LeaderAgent in this case is a no-op. Since leader is forever leader as the actual
* recovery is made by restoring from filesystem.
*/
-private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {
val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
@@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer:
}
}
-private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) {
def createPersistenceEngine(): PersistenceEngine = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 328d95a7a0..563831cc6b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
-import akka.serialization.Serialization
+import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.serializer.Serializer
-private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
+private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer)
extends PersistenceEngine
with Logging {
@@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat
}
private def serializeIntoFile(path: String, value: AnyRef) {
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
+ val serialized = serializer.newInstance().serialize(value)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes)
}
private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
try {
- Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
} catch {
case e: Exception => {
logWarning("Exception while reading persisted file, deleting", e)
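The ZooKeeper engine now uses SerializerInstance.serialize/deserialize, which work with ByteBuffers, while Curator's forPath expects a raw byte array — hence the remaining()/get() copy in serializeIntoFile above. A small sketch of that conversion, assuming a JavaSerializer and a throwaway string payload (the object name is illustrative):

    import java.nio.ByteBuffer

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    object ZkSerializationSketch {
      def main(args: Array[String]): Unit = {
        val instance = new JavaSerializer(new SparkConf()).newInstance()

        // serialize() returns a ByteBuffer; copy it into the Array[Byte] that
        // zk.create().forPath(path, bytes) would take.
        val buffer: ByteBuffer = instance.serialize("worker_state")
        val bytes = new Array[Byte](buffer.remaining())
        buffer.get(bytes)

        // deserializeFromFile wraps the znode payload back into a ByteBuffer.
        val restored = instance.deserialize[String](ByteBuffer.wrap(bytes))
        assert(restored == "worker_state")
      }
    }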
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index c9fcc7a36c..29debe8081 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -139,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* creating it manually because different [[RpcEnv]] may have different formats.
*/
def uriOf(systemName: String, address: RpcAddress, endpointName: String): String
+
+ /**
+ * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
+ * that contains [[RpcEndpointRef]]s, the deserialization code should be wrapped by this method.
+ */
+ def deserialize[T](deserializationAction: () => T): T
}
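The new deserialize hook is the generic form of what readPersistedData does: any read that may yield RpcEndpointRefs runs inside it so the concrete RpcEnv can bind itself to the refs as they are rebuilt. A minimal wrapper sketch (the object and method names are illustrative):

    import org.apache.spark.rpc.RpcEnv

    object DeserializeSketch {
      // Run a deserialization action inside rpcEnv.deserialize so that any
      // RpcEndpointRef it produces is attached to this RpcEnv.
      def readWithRefs[T](rpcEnv: RpcEnv)(readAction: () => T): T =
        rpcEnv.deserialize(readAction)
    }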
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index f2d87f6834..fc17542abf 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
-import com.google.common.util.concurrent.MoreExecutors
+import akka.serialization.JavaSerializer
import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.rpc._
@@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] (
}
override def toString: String = s"${getClass.getSimpleName}($actorSystem)"
+
+ override def deserialize[T](deserializationAction: () => T): T = {
+ JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
+ deserializationAction()
+ }
+ }
}
private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef(
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
+ final override def equals(that: Any): Boolean = that match {
+ case other: AkkaRpcEndpointRef => actorRef == other.actorRef
+ case _ => false
+ }
+
+ final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
}
/**
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
index f4e56632e4..8c96b0e71d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
@@ -19,18 +19,19 @@
// when they are outside of org.apache.spark.
package other.supplier
+import java.nio.ByteBuffer
+
import scala.collection.mutable
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master._
+import org.apache.spark.serializer.Serializer
class CustomRecoveryModeFactory(
conf: SparkConf,
- serialization: Serialization
-) extends StandaloneRecoveryModeFactory(conf, serialization) {
+ serializer: Serializer
+) extends StandaloneRecoveryModeFactory(conf, serializer) {
CustomRecoveryModeFactory.instantiationAttempts += 1
@@ -40,7 +41,7 @@ class CustomRecoveryModeFactory(
*
*/
override def createPersistenceEngine(): PersistenceEngine =
- new CustomPersistenceEngine(serialization)
+ new CustomPersistenceEngine(serializer)
/**
* Create an instance of LeaderAgent that decides who gets elected as master.
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory {
@volatile var instantiationAttempts = 0
}
-class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
+class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine {
val data = mutable.HashMap[String, Array[Byte]]()
CustomPersistenceEngine.lastInstance = Some(this)
@@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def persist(name: String, obj: Object): Unit = {
CustomPersistenceEngine.persistAttempts += 1
- serialization.serialize(obj) match {
- case util.Success(bytes) => data += name -> bytes
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
+ val serialized = serializer.newInstance().serialize(obj)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ data += name -> bytes
}
/**
@@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def read[T: ClassTag](prefix: String): Seq[T] = {
CustomPersistenceEngine.readAttempts += 1
- val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
val results = for ((name, bytes) <- data; if name.startsWith(prefix))
- yield serialization.deserialize(bytes, clazz)
-
- results.find(_.isFailure).foreach {
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
-
- results.flatMap(_.toOption).toSeq
+ yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+ results.toSeq
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 9cb6dd43ba..a8fbaf1d9d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
persistenceEngine.addDriver(driverToPersist)
persistenceEngine.addWorker(workerToPersist)
- val (apps, drivers, workers) = persistenceEngine.readPersistedData()
+ val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv)
apps.map(_.id) should contain(appToPersist.id)
drivers.map(_.id) should contain(driverToPersist.id)
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
new file mode 100644
index 0000000000..11e87bd1dd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.deploy.master
+
+import java.net.ServerSocket
+
+import org.apache.commons.lang3.RandomUtils
+import org.apache.curator.test.TestingServer
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}
+import org.apache.spark.serializer.{Serializer, JavaSerializer}
+import org.apache.spark.util.Utils
+
+class PersistenceEngineSuite extends SparkFunSuite {
+
+ test("FileSystemPersistenceEngine") {
+ val dir = Utils.createTempDir()
+ try {
+ val conf = new SparkConf()
+ testPersistenceEngine(conf, serializer =>
+ new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer)
+ )
+ } finally {
+ Utils.deleteRecursively(dir)
+ }
+ }
+
+ test("ZooKeeperPersistenceEngine") {
+ val conf = new SparkConf()
+ // TestingServer logs a port conflict instead of throwing, so we have to find a free port
+ // ourselves. This cannot guarantee that zkTestServer always starts, because there is a window
+ // between finding a free port and starting zkTestServer, but the chance of a collision in that
+ // window should be very low.
+ val zkTestServer = new TestingServer(findFreePort(conf))
+ try {
+ testPersistenceEngine(conf, serializer => {
+ conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString)
+ new ZooKeeperPersistenceEngine(conf, serializer)
+ })
+ } finally {
+ zkTestServer.stop()
+ }
+ }
+
+ private def testPersistenceEngine(
+ conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = {
+ val serializer = new JavaSerializer(conf)
+ val persistenceEngine = persistenceEngineCreator(serializer)
+ persistenceEngine.persist("test_1", "test_1_value")
+ assert(Seq("test_1_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.persist("test_2", "test_2_value")
+ assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet)
+ persistenceEngine.unpersist("test_1")
+ assert(Seq("test_2_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.unpersist("test_2")
+ assert(persistenceEngine.read[String]("test_").isEmpty)
+
+ // Test deserializing objects that contain RpcEndpointRef
+ val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+ try {
+ // Create a real endpoint so that we can test RpcEndpointRef deserialization
+ val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint {
+ override val rpcEnv: RpcEnv = testRpcEnv
+ })
+
+ val workerToPersist = new WorkerInfo(
+ id = "test_worker",
+ host = "127.0.0.1",
+ port = 10000,
+ cores = 0,
+ memory = 0,
+ endpoint = workerEndpoint,
+ webUiPort = 0,
+ publicAddress = ""
+ )
+
+ persistenceEngine.addWorker(workerToPersist)
+
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(testRpcEnv)
+
+ assert(storedApps.isEmpty)
+ assert(storedDrivers.isEmpty)
+
+ // Check deserializing WorkerInfo
+ assert(storedWorkers.size == 1)
+ val recoveryWorkerInfo = storedWorkers.head
+ assert(workerToPersist.id === recoveryWorkerInfo.id)
+ assert(workerToPersist.host === recoveryWorkerInfo.host)
+ assert(workerToPersist.port === recoveryWorkerInfo.port)
+ assert(workerToPersist.cores === recoveryWorkerInfo.cores)
+ assert(workerToPersist.memory === recoveryWorkerInfo.memory)
+ assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint)
+ assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort)
+ assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress)
+ } finally {
+ testRpcEnv.shutdown()
+ testRpcEnv.awaitTermination()
+ }
+ }
+
+ private def findFreePort(conf: SparkConf): Int = {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ }, conf)._2
+ }
+}
diff --git a/pom.xml b/pom.xml
index 370c95dd03..aa49e2ab72 100644
--- a/pom.xml
+++ b/pom.xml
@@ -749,6 +749,12 @@
<version>${curator.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.curator</groupId>
+ <artifactId>curator-test</artifactId>
+ <version>${curator.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<version>${hadoop.version}</version>