author     Hossein <hossein@databricks.com>        2016-10-30 16:17:23 -0700
committer  Felix Cheung <felixcheung@apache.org>   2016-10-30 16:17:23 -0700
commit     2881a2d1d1a650a91df2c6a01275eba14a43b42a (patch)
tree       1083f14a8b284f1ebdb9e69a0b842edf6b14116d /core/src
parent     8ae2da0b2551011e2f6cf02907a1e20c138a4b2f (diff)
[SPARK-17919] Make timeout to RBackend configurable in SparkR
## What changes were proposed in this pull request?

This patch makes the RBackend connection timeout configurable by the user.

## How was this patch tested?

N/A

Author: Hossein <hossein@databricks.com>

Closes #15471 from falaki/SPARK-17919.
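For context, a minimal sketch of setting the new keys from the JVM side. The config names are the ones added by this patch; the surrounding setup is illustrative, and both values are in seconds:

```scala
import org.apache.spark.SparkConf

// Override the SparkR backend timeout and heartbeat (both in seconds).
val conf = new SparkConf()
  .setAppName("sparkr-app") // illustrative app name
  .set("spark.r.backendConnectionTimeout", "6000")
  .set("spark.r.heartBeatInterval", "100")
```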
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackend.scala         | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala  | 39
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRunner.scala          |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala   | 30
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/RRunner.scala         |  7
5 files changed, 87 insertions(+), 7 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 41d0a85ee3..550746c552 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import io.netty.handler.timeout.ReadTimeoutHandler
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
@@ -43,7 +44,10 @@ private[spark] class RBackend {
def init(): Int = {
val conf = new SparkConf()
- bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+ bossGroup = new NioEventLoopGroup(
+ conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
@@ -63,6 +67,7 @@ private[spark] class RBackend {
// initialBytesToStrip = 4, i.e. strip out the length field itself
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
+ .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
.addLast("handler", handler)
}
})
@@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging {
val boundPort = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
+ // The connection timeout is set by the socket client. To make it configurable we pass
+ // the timeout value to the client inside the temp file.
+ val conf = new SparkConf()
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
// tell the R process via temporary file
val path = args(0)
@@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging {
dos.writeInt(boundPort)
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
+ dos.writeInt(backendConnectionTimeout)
dos.close()
f.renameTo(new File(path))
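The four writes above fix the layout of the handshake file. A hedged Scala sketch of a reader that mirrors that order follows; the real consumer is the SparkR R package, and readHandshakeFile is a hypothetical name used only for illustration:

```scala
import java.io.{DataInputStream, File, FileInputStream}

// Hypothetical reader mirroring the write order above, for illustration only.
def readHandshakeFile(path: String): (Int, Int, Int) = {
  val dis = new DataInputStream(new FileInputStream(new File(path)))
  try {
    val boundPort = dis.readInt()    // port RBackend is bound to
    val listenPort = dis.readInt()   // port the JVM listens on for the R process
    val len = dis.readInt()          // SerDe.writeString is length-prefixed
    dis.skipBytes(len)               // skip the rPackages string payload
    val timeoutSecs = dis.readInt()  // new field: connection timeout in seconds
    (boundPort, listenPort, timeoutSecs)
  } finally {
    dis.close()
  }
}
```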
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 1422ef888f..9f5afa29d6 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -18,16 +18,19 @@
package org.apache.spark.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util.concurrent.TimeUnit
import scala.collection.mutable.HashMap
import scala.language.existentials
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.ChannelHandler.Sharable
+import io.netty.handler.timeout.ReadTimeoutException
import org.apache.spark.api.r.SerDe._
import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Handler for RBackend
@@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend)
writeString(dos, s"Error: unknown method $methodName")
}
} else {
+ // To avoid timeouts while the SparkR driver is blocked reading results, we regularly send
+ // heartbeat responses. The special code +1 signals the client that the backend is
+ // alive and that it should keep blocking for the result.
+ val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
+ val pingRunner = new Runnable {
+ override def run(): Unit = {
+ val pingBaos = new ByteArrayOutputStream()
+ val pingDaos = new DataOutputStream(pingBaos)
+ writeInt(pingDaos, +1)
+ ctx.write(pingBaos.toByteArray)
+ }
+ }
+ val conf = new SparkConf()
+ val heartBeatInterval = conf.getInt(
+ "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+ val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)
+
+ execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ execService.shutdown()
+ execService.awaitTermination(1, TimeUnit.SECONDS)
}
val reply = bos.toByteArray
@@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend)
}
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- // Close the connection when an exception is raised.
- cause.printStackTrace()
- ctx.close()
+ cause match {
+ case timeout: ReadTimeoutException =>
+ // Do nothing. We don't want to time out on read
+ logWarning("Ignoring read timeout in RBackendHandler")
+ case _ =>
+ // Close the connection when an exception is raised.
+ cause.printStackTrace()
+ ctx.close()
+ }
}
def handleMethodCall(
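On the wire, each heartbeat is the single int +1 written ahead of the real reply. A hedged sketch of the client-side loop this implies; the actual client lives in the SparkR R package, and readBackendReply is a hypothetical name:

```scala
import java.io.DataInputStream

// Hypothetical read loop: swallow +1 keep-alive pings until the first int
// of the actual method-call response arrives.
def readBackendReply(dis: DataInputStream): Int = {
  var word = dis.readInt()
  while (word == 1) {  // +1 means "backend alive, keep waiting"
    word = dis.readInt()
  }
  word
}
```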
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 496fdf851f..7ef64723d9 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -333,6 +333,8 @@ private[r] object RRunner {
var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
rCommand = sparkConf.get("spark.r.command", rCommand)
+ val rConnectionTimeout = sparkConf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
val rOptions = "--vanilla"
val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
@@ -344,6 +346,7 @@ private[r] object RRunner {
pb.environment().put("R_TESTS", "")
pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+ pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
val errThread = startStdoutThread(proc)
diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
new file mode 100644
index 0000000000..af67cbbce4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+private[spark] object SparkRDefaults {
+
+ // Default value for spark.r.backendConnectionTimeout config
+ val DEFAULT_CONNECTION_TIMEOUT: Int = 6000
+
+ // Default value for spark.r.heartBeatInterval config
+ val DEFAULT_HEARTBEAT_INTERVAL: Int = 100
+
+ // Default value for spark.r.numRBackendThreads config
+ val DEFAULT_NUM_RBACKEND_THREADS = 2
+}
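All three defaults are interpreted in seconds at their call sites. The handler picks its ping period as min(heartBeatInterval, backendConnectionTimeout - 1), so whatever the user sets, at least one ping lands before the ReadTimeoutHandler fires. A sketch of the arithmetic under the defaults (mirrors the RBackendHandler logic above):

```scala
// Effective ping period under the defaults (mirrors RBackendHandler):
val interval = math.min(100 /* heartBeatInterval */, 6000 /* timeout */ - 1)
// interval == 100: a ping is always sent before the read timeout elapses.
```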
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index d0466830b2..6eb53a8252 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkUserAppException}
-import org.apache.spark.api.r.{RBackend, RUtils}
+import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults}
import org.apache.spark.util.RedirectThread
/**
@@ -51,6 +51,10 @@ object RRunner {
cmd
}
+ // Connection timeout, in seconds, set by the R process on its connection to RBackend.
+ val backendConnectionTimeout = sys.props.getOrElse(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString)
+
// Check if the file path exists.
// If not, change directory to current working directory for YARN cluster mode
val rF = new File(rFile)
@@ -81,6 +85,7 @@ object RRunner {
val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+ env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout)
val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
// Put the R package directories into an env variable of comma-separated paths
env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))