From 2881a2d1d1a650a91df2c6a01275eba14a43b42a Mon Sep 17 00:00:00 2001
From: Hossein
Date: Sun, 30 Oct 2016 16:17:23 -0700
Subject: [SPARK-17919] Make timeout to RBackend configurable in SparkR

## What changes were proposed in this pull request?

This patch makes the RBackend connection timeout configurable by the user.

## How was this patch tested?

N/A

Author: Hossein

Closes #15471 from falaki/SPARK-17919.
---
 .../scala/org/apache/spark/api/r/RBackend.scala    | 15 +++++++--
 .../org/apache/spark/api/r/RBackendHandler.scala   | 39 +++++++++++++++++++---
 .../scala/org/apache/spark/api/r/RRunner.scala     |  3 ++
 .../org/apache/spark/api/r/SparkRDefaults.scala    | 30 +++++++++++++++++
 .../scala/org/apache/spark/deploy/RRunner.scala    |  7 +++-
 5 files changed, 87 insertions(+), 7 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala

(limited to 'core/src')

diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 41d0a85ee3..550746c552 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
 import io.netty.channel.nio.NioEventLoopGroup
 import io.netty.channel.socket.SocketChannel
 import io.netty.channel.socket.nio.NioServerSocketChannel
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder
 import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import io.netty.handler.timeout.ReadTimeoutHandler
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
@@ -43,7 +44,10 @@ private[spark] class RBackend {
 
   def init(): Int = {
     val conf = new SparkConf()
-    bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
+    val backendConnectionTimeout = conf.getInt(
+      "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+    bossGroup = new NioEventLoopGroup(
+      conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
     val workerGroup = bossGroup
     val handler = new RBackendHandler(this)
 
@@ -63,6 +67,7 @@ private[spark] class RBackend {
           // initialBytesToStrip = 4, i.e. strip out the length field itself
           new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
           .addLast("decoder", new ByteArrayDecoder())
+          .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
           .addLast("handler", handler)
       }
     })
@@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging {
       val boundPort = sparkRBackend.init()
       val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
       val listenPort = serverSocket.getLocalPort()
+      // Connection timeout is set by socket client. To make it configurable we will pass the
+      // timeout value to client inside the temp file
+      val conf = new SparkConf()
+      val backendConnectionTimeout = conf.getInt(
+        "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
 
       // tell the R process via temporary file
       val path = args(0)
@@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging {
       dos.writeInt(boundPort)
       dos.writeInt(listenPort)
       SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
+      dos.writeInt(backendConnectionTimeout)
       dos.close()
       f.renameTo(new File(path))
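Note on the two settings introduced above: `RBackend` reads them from a freshly constructed `SparkConf`, which is populated from `spark.*` JVM system properties, so they can be supplied via `spark-defaults.conf`, `--conf`, or the driver JVM's properties. A minimal sketch of the property-based route (values shown are just the defaults from `SparkRDefaults`; nothing in this sketch is part of the patch itself):

```scala
import org.apache.spark.SparkConf

// Sketch only: new SparkConf() picks up any "spark.*" JVM system property,
// which is one way these values reach RBackend.init() and RBackend.main().
sys.props("spark.r.backendConnectionTimeout") = "6000" // seconds; default 6000
sys.props("spark.r.numRBackendThreads") = "2"          // default 2

val conf = new SparkConf()
val timeout = conf.getInt("spark.r.backendConnectionTimeout", 6000) // -> 6000
```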
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 1422ef888f..9f5afa29d6 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -18,16 +18,19 @@
 package org.apache.spark.api.r
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable.HashMap
 import scala.language.existentials
 
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
 import io.netty.channel.ChannelHandler.Sharable
+import io.netty.handler.timeout.ReadTimeoutException
 
 import org.apache.spark.api.r.SerDe._
 import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * Handler for RBackend
@@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend)
           writeString(dos, s"Error: unknown method $methodName")
         }
       } else {
+        // To avoid timeouts when reading results in SparkR driver, we will be regularly sending
+        // heartbeat responses. We use special code +1 to signal the client that backend is
+        // alive and it should continue blocking for result.
+        val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
+        val pingRunner = new Runnable {
+          override def run(): Unit = {
+            val pingBaos = new ByteArrayOutputStream()
+            val pingDaos = new DataOutputStream(pingBaos)
+            writeInt(pingDaos, +1)
+            ctx.write(pingBaos.toByteArray)
+          }
+        }
+        val conf = new SparkConf()
+        val heartBeatInterval = conf.getInt(
+          "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
+        val backendConnectionTimeout = conf.getInt(
+          "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+        val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)
+
+        execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
         handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+        execService.shutdown()
+        execService.awaitTermination(1, TimeUnit.SECONDS)
       }
 
     val reply = bos.toByteArray
@@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend)
   }
 
   override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
-    // Close the connection when an exception is raised.
-    cause.printStackTrace()
-    ctx.close()
+    cause match {
+      case timeout: ReadTimeoutException =>
+        // Do nothing. We don't want to timeout on read
+        logWarning("Ignoring read timeout in RBackendHandler")
+      case _ =>
+        // Close the connection when an exception is raised.
+        cause.printStackTrace()
+        ctx.close()
+    }
   }
 
   def handleMethodCall(
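The keep-alive block added to `channelRead0` above writes the integer `+1` on a fixed schedule while a long-running method call executes; the consumer of those pings is the SparkR client in the R package, which lies outside this core/src-limited diff. Purely as an illustration of the protocol, here is a hypothetical Scala rendering of that client-side read loop (the function name and shape are assumptions, not Spark code):

```scala
import java.io.DataInputStream

// Hypothetical reader: every +1 means "backend is alive, keep blocking";
// each ping also gives the client fresh traffic before its socket read
// timeout expires, so the connection survives slow JVM calls.
def readReturnStatus(dis: DataInputStream): Int = {
  var status = dis.readInt()
  while (status == 1) {
    status = dis.readInt()
  }
  status // the first non-ping value is the real return status of the call
}
```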
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 496fdf851f..7ef64723d9 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -333,6 +333,8 @@ private[r] object RRunner {
     var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
     rCommand = sparkConf.get("spark.r.command", rCommand)
 
+    val rConnectionTimeout = sparkConf.getInt(
+      "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
     val rOptions = "--vanilla"
     val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
     val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
@@ -344,6 +346,7 @@ private[r] object RRunner {
       pb.environment().put("R_TESTS", "")
       pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
       pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+      pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
       pb.redirectErrorStream(true)  // redirect stderr into stdout
       val proc = pb.start()
       val errThread = startStdoutThread(proc)

diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
new file mode 100644
index 0000000000..af67cbbce4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+private[spark] object SparkRDefaults {
+
+  // Default value for spark.r.backendConnectionTimeout config
+  val DEFAULT_CONNECTION_TIMEOUT: Int = 6000
+
+  // Default value for spark.r.heartBeatInterval config
+  val DEFAULT_HEARTBEAT_INTERVAL: Int = 100
+
+  // Default value for spark.r.numRBackendThreads config
+  val DEFAULT_NUM_RBACKEND_THREADS = 2
+}
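For reference, the ping period used in `RBackendHandler` is derived from two of the defaults above. A small worked sketch of that arithmetic, using the default values (both settings are treated as seconds):

```scala
// Defaults from SparkRDefaults, both interpreted as seconds.
val heartBeatInterval = 100          // spark.r.heartBeatInterval
val backendConnectionTimeout = 6000  // spark.r.backendConnectionTimeout

// RBackendHandler schedules pings every min(heartBeatInterval, timeout - 1)
// seconds, so at least one ping always arrives inside each timeout window.
val interval = math.min(heartBeatInterval, backendConnectionTimeout - 1) // 100
assert(interval < backendConnectionTimeout)
```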
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index d0466830b2..6eb53a8252 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{SparkException, SparkUserAppException}
-import org.apache.spark.api.r.{RBackend, RUtils}
+import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults}
 import org.apache.spark.util.RedirectThread
 
 /**
@@ -51,6 +51,10 @@ object RRunner {
         cmd
       }
 
+    // Connection timeout set by R process on its connection to RBackend in seconds.
+    val backendConnectionTimeout = sys.props.getOrElse(
+      "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString)
+
     // Check if the file path exists.
     // If not, change directory to current working directory for YARN cluster mode
     val rF = new File(rFile)
@@ -81,6 +85,7 @@ object RRunner {
       val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava)
       val env = builder.environment()
       env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+      env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout)
       val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
       // Put the R package directories into an env variable of comma-separated paths
      env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
--
cgit v1.2.3
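Both launchers now export `SPARKR_BACKEND_CONNECTION_TIMEOUT` to the child R process; the actual consumer is R code in the SparkR package, so the following is only a hedged Scala sketch of what that consumer does with the variable (everything below other than the environment variable name and the 6000-second default is illustrative):

```scala
// Hypothetical consumer-side check: fall back to the patch's default of
// 6000 seconds when the launcher did not export the variable.
val timeoutSeconds: Int =
  sys.env.get("SPARKR_BACKEND_CONNECTION_TIMEOUT").map(_.toInt).getOrElse(6000)
println(s"connecting to RBackend with a $timeoutSeconds second read timeout")
```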