Diffstat (limited to 'core/src/test/scala/org/apache/spark/deploy')
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala   31
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala                  2
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala      126
3 files changed, 140 insertions(+), 19 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
index f4e56632e4..8c96b0e71d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
@@ -19,18 +19,19 @@
// when they are outside of org.apache.spark.
package other.supplier
+import java.nio.ByteBuffer
+
import scala.collection.mutable
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master._
+import org.apache.spark.serializer.Serializer
class CustomRecoveryModeFactory(
conf: SparkConf,
- serialization: Serialization
-) extends StandaloneRecoveryModeFactory(conf, serialization) {
+ serializer: Serializer
+) extends StandaloneRecoveryModeFactory(conf, serializer) {
CustomRecoveryModeFactory.instantiationAttempts += 1
@@ -40,7 +41,7 @@ class CustomRecoveryModeFactory(
*
*/
override def createPersistenceEngine(): PersistenceEngine =
- new CustomPersistenceEngine(serialization)
+ new CustomPersistenceEngine(serializer)
/**
* Create an instance of LeaderAgent that decides who gets elected as master.
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory {
@volatile var instantiationAttempts = 0
}
-class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
+class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine {
val data = mutable.HashMap[String, Array[Byte]]()
CustomPersistenceEngine.lastInstance = Some(this)
@@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def persist(name: String, obj: Object): Unit = {
CustomPersistenceEngine.persistAttempts += 1
- serialization.serialize(obj) match {
- case util.Success(bytes) => data += name -> bytes
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
+ val serialized = serializer.newInstance().serialize(obj)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ data += name -> bytes
}
/**
@@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def read[T: ClassTag](prefix: String): Seq[T] = {
CustomPersistenceEngine.readAttempts += 1
- val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
val results = for ((name, bytes) <- data; if name.startsWith(prefix))
- yield serialization.deserialize(bytes, clazz)
-
- results.find(_.isFailure).foreach {
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
-
- results.flatMap(_.toOption).toSeq
+ yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+ results.toSeq
}
}
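
Note: the hunks above replace the Akka Serialization round trip with Spark's own Serializer API. A minimal standalone sketch of that round trip, using JavaSerializer only because the new PersistenceEngineSuite below does the same (the object name is illustrative, not part of the patch):

import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object SerializerRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val serializer = new JavaSerializer(new SparkConf())

    // persist side: serialize to a ByteBuffer and copy its remaining bytes out,
    // as CustomPersistenceEngine.persist does above
    val buffer = serializer.newInstance().serialize("test_value")
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)

    // read side: wrap the stored bytes and deserialize with a fresh instance,
    // as CustomPersistenceEngine.read does above
    val restored = serializer.newInstance().deserialize[String](ByteBuffer.wrap(bytes))
    assert(restored == "test_value")
  }
}
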
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 9cb6dd43ba..a8fbaf1d9d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
persistenceEngine.addDriver(driverToPersist)
persistenceEngine.addWorker(workerToPersist)
- val (apps, drivers, workers) = persistenceEngine.readPersistedData()
+ val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv)
apps.map(_.id) should contain(appToPersist.id)
drivers.map(_.id) should contain(driverToPersist.id)
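
Note: this hunk sits inside MasterSuite's custom recovery mode test, where the persistence engine comes from CustomRecoveryModeFactory above. A hedged sketch of how such a factory is typically selected through configuration (keys read by the standalone Master's CUSTOM recovery mode; the object name is illustrative, not part of the patch):

import org.apache.spark.SparkConf
import other.supplier.CustomRecoveryModeFactory

object CustomRecoveryConfSketch {
  // Assumed wiring, not part of this patch: point the Master at the custom factory.
  val conf: SparkConf = new SparkConf(loadDefaults = false)
    .set("spark.deploy.recoveryMode", "CUSTOM")
    .set("spark.deploy.recoveryMode.factory",
      classOf[CustomRecoveryModeFactory].getCanonicalName)
}
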
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
new file mode 100644
index 0000000000..11e87bd1dd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.deploy.master
+
+import java.net.ServerSocket
+
+import org.apache.commons.lang3.RandomUtils
+import org.apache.curator.test.TestingServer
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
+import org.apache.spark.util.Utils
+
+class PersistenceEngineSuite extends SparkFunSuite {
+
+ test("FileSystemPersistenceEngine") {
+ val dir = Utils.createTempDir()
+ try {
+ val conf = new SparkConf()
+ testPersistenceEngine(conf, serializer =>
+ new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer)
+ )
+ } finally {
+ Utils.deleteRecursively(dir)
+ }
+ }
+
+ test("ZooKeeperPersistenceEngine") {
+ val conf = new SparkConf()
+ // TestingServer logs the port conflict exception rather than throwing an exception.
+ // So we have to find a free port by ourselves. This approach cannot guarantee always starting
+ // zkTestServer successfully because there is a time gap between finding a free port and
+ // starting zkTestServer. But the failure possibility should be very low.
+ val zkTestServer = new TestingServer(findFreePort(conf))
+ try {
+ testPersistenceEngine(conf, serializer => {
+ conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString)
+ new ZooKeeperPersistenceEngine(conf, serializer)
+ })
+ } finally {
+ zkTestServer.stop()
+ }
+ }
+
+ private def testPersistenceEngine(
+ conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = {
+ val serializer = new JavaSerializer(conf)
+ val persistenceEngine = persistenceEngineCreator(serializer)
+ persistenceEngine.persist("test_1", "test_1_value")
+ assert(Seq("test_1_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.persist("test_2", "test_2_value")
+ assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet)
+ persistenceEngine.unpersist("test_1")
+ assert(Seq("test_2_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.unpersist("test_2")
+ assert(persistenceEngine.read[String]("test_").isEmpty)
+
+ // Test deserializing objects that contain RpcEndpointRef
+ val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+ try {
+ // Create a real endpoint so that we can test RpcEndpointRef deserialization
+ val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint {
+ override val rpcEnv: RpcEnv = testRpcEnv
+ })
+
+ val workerToPersist = new WorkerInfo(
+ id = "test_worker",
+ host = "127.0.0.1",
+ port = 10000,
+ cores = 0,
+ memory = 0,
+ endpoint = workerEndpoint,
+ webUiPort = 0,
+ publicAddress = ""
+ )
+
+ persistenceEngine.addWorker(workerToPersist)
+
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(testRpcEnv)
+
+ assert(storedApps.isEmpty)
+ assert(storedDrivers.isEmpty)
+
+ // Check deserializing WorkerInfo
+ assert(storedWorkers.size == 1)
+ val recoveryWorkerInfo = storedWorkers.head
+ assert(workerToPersist.id === recoveryWorkerInfo.id)
+ assert(workerToPersist.host === recoveryWorkerInfo.host)
+ assert(workerToPersist.port === recoveryWorkerInfo.port)
+ assert(workerToPersist.cores === recoveryWorkerInfo.cores)
+ assert(workerToPersist.memory === recoveryWorkerInfo.memory)
+ assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint)
+ assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort)
+ assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress)
+ } finally {
+ testRpcEnv.shutdown()
+ testRpcEnv.awaitTermination()
+ }
+ }
+
+ private def findFreePort(conf: SparkConf): Int = {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ }, conf)._2
+ }
+}
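
Note: as the comment in the ZooKeeperPersistenceEngine test explains, any pre-scan for a free port leaves a window before TestingServer binds it. A hedged sketch of the same idea using an OS-assigned ephemeral port (hypothetical helper, not in the suite), which carries the same caveat:

import java.net.ServerSocket

object EphemeralPortSketch {
  // Hypothetical alternative to findFreePort: let the OS pick an ephemeral port.
  // The gap between close() and TestingServer binding the port still exists.
  def findEphemeralPort(): Int = {
    val socket = new ServerSocket(0)
    try socket.getLocalPort finally socket.close()
  }
}
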