aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2017-01-27 15:07:57 -0800
committerShixiong Zhu <shixiong@databricks.com>2017-01-27 15:07:57 -0800
commit21aa8c32ba7a29aafc000ecce2e6c802ced6a009 (patch)
tree27cbf9ae131d63e2632b255088f82509493efb96 /core/src/main/scala/org
parenta7ab6f9a8fdfb927f0bcefdc87a92cc82fac4223 (diff)
downloadspark-21aa8c32ba7a29aafc000ecce2e6c802ced6a009.tar.gz
spark-21aa8c32ba7a29aafc000ecce2e6c802ced6a009.tar.bz2
spark-21aa8c32ba7a29aafc000ecce2e6c802ced6a009.zip
[SPARK-19365][CORE] Optimize RequestMessage serialization
## What changes were proposed in this pull request? Right now Netty RPC serializes `RequestMessage` using Java serialization, and the size of a single message (e.g., `RequestMessage(..., "hello")`) is almost 1KB. This PR optimizes it by serializing `RequestMessage` manually (eliminating unnecessary information from most messages, e.g., class names of `RequestMessage`, `NettyRpcEndpointRef`, ...), and reduces the above message size to 100+ bytes. ## How was this patch tested? Jenkins. I also did a simple test to measure the improvement: Before: ``` $ bin/spark-shell --master local-cluster[1,4,1024] ... scala> for (i <- 1 to 10) { | val start = System.nanoTime | val s = sc.parallelize(1 to 1000000, 10 * 1000).count() | val end = System.nanoTime | println(s"$i\t" + ((end - start)/1000/1000)) | } 1 6830 2 4353 3 3322 4 3107 5 3235 6 3139 7 3156 8 3166 9 3091 10 3029 ``` After: ``` $ bin/spark-shell --master local-cluster[1,4,1024] ... scala> for (i <- 1 to 10) { | val start = System.nanoTime | val s = sc.parallelize(1 to 1000000, 10 * 1000).count() | val end = System.nanoTime | println(s"$i\t" + ((end - start)/1000/1000)) | } 1 6431 2 3643 3 2913 4 2679 5 2760 6 2710 7 2747 8 2793 9 2679 10 2651 ``` I also captured the TCP packets for this test. Before this patch, the total size of the TCP packets was ~1.5GB. After this patch, it drops to ~1.2GB. Author: Shixiong Zhu <shixiong@databricks.com> Closes #16706 from zsxwing/rpc-opt.
Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala119
2 files changed, 99 insertions, 25 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
index b9db60a779..fdbccc9e74 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
@@ -25,10 +25,11 @@ import org.apache.spark.SparkException
* The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
* connection and can only be reached via the client that sent the endpoint reference.
*
- * @param rpcAddress The socket address of the endpoint.
+ * @param rpcAddress The socket address of the endpoint. It's `null` when this address points to
+ * an endpoint in a client `NettyRpcEnv`.
* @param name Name of the endpoint.
*/
-private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) {
require(name != null, "RpcEndpoint name must be provided.")
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 1e448b2f1a..ff5e39a8dc 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -37,8 +37,8 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server._
import org.apache.spark.rpc._
-import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}
private[netty] class NettyRpcEnv(
val conf: SparkConf,
@@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv(
}
} else {
// Message to a remote RPC endpoint.
- postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
+ postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
}
}
@@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv(
}(ThreadUtils.sameThread)
dispatcher.postLocalMessage(message, p)
} else {
- val rpcMessage = RpcOutboxMessage(serialize(message),
+ val rpcMessage = RpcOutboxMessage(message.serialize(this),
onFailure,
(client, response) => onSuccess(deserialize[Any](client, response)))
postToOutbox(message.receiver, rpcMessage)
@@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv(
javaSerializerInstance.serialize(content)
}
+ /**
+ * Returns a [[SerializationStream]] that forwards the serialized bytes to `out`.
+ */
+ private[netty] def serializeStream(out: OutputStream): SerializationStream = {
+ javaSerializerInstance.serializeStream(out)
+ }
+
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
NettyRpcEnv.currentClient.withValue(client) {
deserialize { () =>
@@ -480,16 +487,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
*/
private[netty] class NettyRpcEndpointRef(
@transient private val conf: SparkConf,
- endpointAddress: RpcEndpointAddress,
- @transient @volatile private var nettyEnv: NettyRpcEnv)
- extends RpcEndpointRef(conf) with Serializable with Logging {
+ private val endpointAddress: RpcEndpointAddress,
+ @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
@transient @volatile var client: TransportClient = _
- private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
- private val _name = endpointAddress.name
-
- override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
+ override def address: RpcAddress =
+ if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
@@ -501,34 +505,103 @@ private[netty] class NettyRpcEndpointRef(
out.defaultWriteObject()
}
- override def name: String = _name
+ override def name: String = endpointAddress.name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
- nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
+ nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
}
override def send(message: Any): Unit = {
require(message != null, "Message is null")
- nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
+ nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
}
- override def toString: String = s"NettyRpcEndpointRef(${_address})"
-
- def toURI: URI = new URI(_address.toString)
+ override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})"
final override def equals(that: Any): Boolean = that match {
- case other: NettyRpcEndpointRef => _address == other._address
+ case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress
case _ => false
}
- final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
+ final override def hashCode(): Int =
+ if (endpointAddress == null) 0 else endpointAddress.hashCode()
}
/**
* The message that is sent from the sender to the receiver.
+ *
+ * @param senderAddress the sender address. It's `null` if this message is from a client
+ * `NettyRpcEnv`.
+ * @param receiver the receiver of this message.
+ * @param content the message content.
*/
-private[netty] case class RequestMessage(
- senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
+private[netty] class RequestMessage(
+ val senderAddress: RpcAddress,
+ val receiver: NettyRpcEndpointRef,
+ val content: Any) {
+
+ /** Manually serialize [[RequestMessage]] to minimize the size. */
+ def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
+ val bos = new ByteBufferOutputStream()
+ val out = new DataOutputStream(bos)
+ try {
+ writeRpcAddress(out, senderAddress)
+ writeRpcAddress(out, receiver.address)
+ out.writeUTF(receiver.name)
+ val s = nettyEnv.serializeStream(out)
+ try {
+ s.writeObject(content)
+ } finally {
+ s.close()
+ }
+ } finally {
+ out.close()
+ }
+ bos.toByteBuffer
+ }
+
+ private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = {
+ if (rpcAddress == null) {
+ out.writeBoolean(false)
+ } else {
+ out.writeBoolean(true)
+ out.writeUTF(rpcAddress.host)
+ out.writeInt(rpcAddress.port)
+ }
+ }
+
+ override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)"
+}
+
+private[netty] object RequestMessage {
+
+ private def readRpcAddress(in: DataInputStream): RpcAddress = {
+ val hasRpcAddress = in.readBoolean()
+ if (hasRpcAddress) {
+ RpcAddress(in.readUTF(), in.readInt())
+ } else {
+ null
+ }
+ }
+
+ def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = {
+ val bis = new ByteBufferInputStream(bytes)
+ val in = new DataInputStream(bis)
+ try {
+ val senderAddress = readRpcAddress(in)
+ val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF())
+ val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv)
+ ref.client = client
+ new RequestMessage(
+ senderAddress,
+ ref,
+ // The remaining bytes in `bytes` are the message content.
+ nettyEnv.deserialize(client, bytes))
+ } finally {
+ in.close()
+ }
+ }
+}
/**
* A response that indicates some failure happens in the receiver side.
@@ -574,10 +647,10 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
- val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
+ val requestMessage = RequestMessage(nettyEnv, client, message)
if (requestMessage.senderAddress == null) {
// Create a new message with the socket address of the client as the sender.
- RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+ new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
} else {
// The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
// the listening address