/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty
import java.io.InputStreamReader
import java.nio._
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}
import com.google.common.io.CharStreams
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.ShouldMatchers
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, ShuffleBlockId}
/**
 * Checks how a block fetch between two [[NettyBlockTransferService]] instances behaves under
 * different combinations of the `spark.authenticate` and `spark.authenticate.secret` settings.
 *
 * In every test the first configuration handed to `testConnection` belongs to the service that
 * issues the fetch (the client side) and the second one to the service that serves the block.
 */
class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers {
  test("security default off") {
    val conf = new SparkConf()
      .set("spark.app.id", "app-id")
    testConnection(conf, conf) match {
      case Success(_) => // expected
      case Failure(t) => fail(t)
    }
  }

  test("security on same password") {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    testConnection(conf, conf) match {
      case Success(_) => // expected
      case Failure(t) => fail(t)
    }
  }

  test("security on mismatch password") {
    val conf0 = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
    testConnection(conf0, conf1) match {
      case Success(_) => fail("Should have failed")
      case Failure(t) => t.getMessage should include ("Mismatched response")
    }
  }

  test("security mismatch auth off on server") {
    val conf0 = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val conf1 = conf0.clone.set("spark.authenticate", "false")
    testConnection(conf0, conf1) match {
      case Success(_) => fail("Should have failed")
      case Failure(_) => // any funny error may occur, server will interpret SASL token as RPC
    }
  }

  test("security mismatch auth off on client") {
    val conf0 = new SparkConf()
      .set("spark.authenticate", "false")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val conf1 = conf0.clone.set("spark.authenticate", "true")
    testConnection(conf0, conf1) match {
      case Success(_) => fail("Should have failed")
      case Failure(t) => t.getMessage should include ("Expected SaslMessage")
    }
  }

  /**
   * Starts two block transfer services, one per configuration, and has the first fetch a
   * single block from the second. Both services are closed before this method returns,
   * whether it returns normally or throws.
   *
   * @param conf0 configuration of the service that issues the fetch
   * @param conf1 configuration of the service that serves the block
   * @return `Success(())` if the block was fetched and its content matched what was served;
   *         `Failure(t)` if the fetch was reported as failed through the fetch listener.
   *         Anything else that goes wrong (a content mismatch, a failure during `init`, a
   *         timeout while waiting for the fetch) is thrown out-of-band instead of being
   *         returned.
   */
  private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = {
    val blockManager = mock[BlockDataManager]
    val blockId = ShuffleBlockId(0, 1, 2)
    val blockString = "Hello, world!"
    val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(
      blockString.getBytes(StandardCharsets.UTF_8)))
    when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)

    val securityManager0 = new SecurityManager(conf0)
    val exec0 = new NettyBlockTransferService(conf0, securityManager0, numCores = 1)
    exec0.init(blockManager)
    try {
      val securityManager1 = new SecurityManager(conf1)
      val exec1 = new NettyBlockTransferService(conf1, securityManager1, numCores = 1)
      exec1.init(blockManager)
      try {
        // Deliberately a `match` rather than `Try.map`: `map` would capture a failed content
        // assertion as a Failure, which the mismatch tests would then accept as "expected".
        fetchBlock(exec0, exec1, "1", blockId) match {
          case Success(buf) =>
            try {
              val actualString = CharStreams.toString(
                new InputStreamReader(buf.createInputStream(), StandardCharsets.UTF_8))
              actualString should equal(blockString)
            } finally {
              // Drop the reference taken in fetchBlock even when the assertion above throws.
              buf.release()
            }
            Success(())
          case Failure(t) =>
            Failure(t)
        }
      } finally {
        exec1.close()
      }
    } finally {
      // Services are shut down in reverse order of initialisation.
      exec0.close()
    }
  }

  /**
   * Synchronously fetches a single block, acting as the given executor fetching from another.
   *
   * Waits up to 10 seconds for the fetch listener to be invoked; if it is not, the
   * `TimeoutException` raised by `Await.ready` propagates to the caller rather than being
   * returned as a `Failure`.
   *
   * @param self    service that issues the fetch
   * @param from    service whose host name and port the fetch is sent to
   * @param execId  executor id passed through to `fetchBlocks`
   * @param blockId block to fetch
   * @return `Success` holding the fetched buffer, retained once on behalf of the caller (who
   *         must release it), or `Failure` holding the exception reported to the listener
   */
  private def fetchBlock(
      self: BlockTransferService,
      from: BlockTransferService,
      execId: String,
      blockId: BlockId): Try[ManagedBuffer] = {
    val promise = Promise[ManagedBuffer]()

    self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
      new BlockFetchingListener {
        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
          promise.failure(exception)
        }

        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
          // Retain before handing the buffer over; the caller releases it when done.
          promise.success(data.retain())
        }
      })

    Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
    // Await.ready returned without throwing, so the future is completed and `value` is defined.
    promise.future.value.get
  }
}