aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/org/apache/spark/network/SecurityMessage.scala')
-rw-r--r--core/src/main/scala/org/apache/spark/network/SecurityMessage.scala163
1 files changed, 163 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
new file mode 100644
index 0000000000..0d9f743b36
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.StringBuilder
+
+import org.apache.spark._
+import org.apache.spark.network._
+
+/**
+ * SecurityMessage is class that contains the connectionId and sasl token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the ConnectionManager
+ * and easily consumed by users when received.
+ * The api was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous and its possible for the server side (receiving)
+ * to get multiple different types of messages on the same connection the connectionId
+ * is used to know which connnection the security message is intended for.
+ *
+ * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to occur
+ * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message.
+ * node_1 receives the message from node_0 but before it can process it and send a response,
+ * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now node_0 gets
+ * the message and it needs to decide if this message is in response to it being a client
+ * (from the first send) or if its just node_1 trying to connect to it to send data. This
+ * is where the connectionId field is used. node_0 can lookup the connectionId to see if
+ * it is in response to it being a client or if its in response to someone sending other data.
+ *
+ * The format of a SecurityMessage as its sent is:
+ * - Length of the ConnectionId
+ * - ConnectionId
+ * - Length of the token
+ * - Token
+ */
+private[spark] class SecurityMessage() extends Logging {
+
+ private var connectionId: String = null
+ private var token: Array[Byte] = null
+
+ def set(byteArr: Array[Byte], newconnectionId: String) {
+ if (byteArr == null) {
+ token = new Array[Byte](0)
+ } else {
+ token = byteArr
+ }
+ connectionId = newconnectionId
+ }
+
+ /**
+ * Read the given buffer and set the members of this class.
+ */
+ def set(buffer: ByteBuffer) {
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ connectionId = idBuilder.toString()
+
+ val tokenLength = buffer.getInt()
+ token = new Array[Byte](tokenLength)
+ if (tokenLength > 0) {
+ buffer.get(token, 0, tokenLength)
+ }
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getConnectionId: String = {
+ return connectionId
+ }
+
+ def getToken: Array[Byte] = {
+ return token
+ }
+
+ /**
+ * Create a BufferMessage that can be sent by the ConnectionManager containing
+ * the security information from this class.
+ * @return BufferMessage
+ */
+ def toBufferMessage: BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ // 4 bytes for the length of the connectionId
+ // connectionId is of type char so multiple the length by 2 to get number of bytes
+ // 4 bytes for the length of token
+ // token is a byte buffer so just take the length
+ var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
+ buffer.putInt(connectionId.length())
+ connectionId.foreach((x: Char) => buffer.putChar(x))
+ buffer.putInt(token.length)
+
+ if (token.length > 0) {
+ buffer.put(token)
+ }
+ buffer.flip()
+ buffers += buffer
+
+ var message = Message.createBufferMessage(buffers)
+ logDebug("message total size is : " + message.size)
+ message.isSecurityNeg = true
+ return message
+ }
+
+ override def toString: String = {
+ "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
+ }
+}
+
+private[spark] object SecurityMessage {
+
+ /**
+ * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+ * of the BufferMessage and populating the SecurityMessage fields.
+ * @param bufferMessage is a BufferMessage that was received
+ * @return new SecurityMessage
+ */
+ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(bufferMessage)
+ newSecurityMessage
+ }
+
+ /**
+ * Create a SecurityMessage to send from a given saslResponse.
+ * @param response is the response to a challenge from the SaslClient or Saslserver
+ * @param connectionId the client connectionId we are negotiation authentication for
+ * @return a new SecurityMessage
+ */
+ def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(response, connectionId)
+ newSecurityMessage
+ }
+}