author     Hossein <hossein@databricks.com>        2016-10-30 16:17:23 -0700
committer  Felix Cheung <felixcheung@apache.org>   2016-10-30 16:17:23 -0700
commit     2881a2d1d1a650a91df2c6a01275eba14a43b42a
tree       1083f14a8b284f1ebdb9e69a0b842edf6b14116d
parent     8ae2da0b2551011e2f6cf02907a1e20c138a4b2f
download   spark-2881a2d1d1a650a91df2c6a01275eba14a43b42a.tar.gz
           spark-2881a2d1d1a650a91df2c6a01275eba14a43b42a.tar.bz2
           spark-2881a2d1d1a650a91df2c6a01275eba14a43b42a.zip
[SPARK-17919] Make timeout to RBackend configurable in SparkR
## What changes were proposed in this pull request?

This patch makes the RBackend connection timeout configurable by the user.

## How was this patch tested?

N/A

Author: Hossein <hossein@databricks.com>

Closes #15471 from falaki/SPARK-17919.
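The two properties introduced by this change (`spark.r.backendConnectionTimeout` and `spark.r.heartBeatInterval`) are ordinary Spark configs. Below is a minimal sketch of supplying them from SparkR; the values and app name are illustrative, and depending on how the session is launched they may instead need to be passed via `spark-submit --conf` or `spark-defaults.conf` so that the backend JVM sees them at startup.

```r
library(SparkR)

# Illustrative values only; both settings are in seconds and default to
# 6000 (backendConnectionTimeout) and 100 (heartBeatInterval).
sparkR.session(
  appName = "long-running-sparkr-job",   # hypothetical app name
  sparkConfig = list(
    spark.r.backendConnectionTimeout = "12000",
    spark.r.heartBeatInterval = "60"
  )
)
```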
-rw-r--r--  R/pkg/R/backend.R                                                  20
-rw-r--r--  R/pkg/R/client.R                                                    2
-rw-r--r--  R/pkg/R/sparkR.R                                                    8
-rw-r--r--  R/pkg/inst/worker/daemon.R                                          4
-rw-r--r--  R/pkg/inst/worker/worker.R                                          7
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackend.scala          15
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala   39
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRunner.scala            3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala    30
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/RRunner.scala           7
-rw-r--r--  docs/configuration.md                                               15
11 files changed, 134 insertions(+), 16 deletions(-)
diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R
index 03e70bb2cb..0a789e6c37 100644
--- a/R/pkg/R/backend.R
+++ b/R/pkg/R/backend.R
@@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) {
conn <- get(".sparkRCon", .sparkREnv)
writeBin(requestMessage, conn)
- # TODO: check the status code to output error information
returnStatus <- readInt(conn)
+ handleErrors(returnStatus, conn)
+
+ # The backend will send +1 as a keep-alive value to prevent connection timeouts
+ # on very long running jobs. See spark.r.heartBeatInterval.
+ while (returnStatus == 1) {
+ returnStatus <- readInt(conn)
+ handleErrors(returnStatus, conn)
+ }
+
+ readObject(conn)
+}
+
+# Helper function to check the returned status and raise an appropriate error for the user
+handleErrors <- function(returnStatus, conn) {
if (length(returnStatus) == 0) {
stop("No status is returned. Java SparkR backend might have failed.")
}
- if (returnStatus != 0) {
+
+ # 0 is success and +1 is reserved for heartbeats; negative values indicate errors.
+ if (returnStatus < 0) {
stop(readString(conn))
}
- readObject(conn)
}
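For context, here is a self-contained sketch of the client-side protocol that the new loop in `invokeJava` implements. `readBackendReply` is a hypothetical helper (not part of SparkR) written against SparkR's internal readers; it assumes `conn` is the open backend connection.

```r
# Hypothetical helper mirroring the patched invokeJava/handleErrors logic.
readBackendReply <- function(conn) {
  status <- SparkR:::readInt(conn)
  # +1 is a keep-alive heartbeat; keep blocking until a terminal status arrives.
  while (length(status) == 1 && status == 1) {
    status <- SparkR:::readInt(conn)
  }
  if (length(status) == 0) {
    stop("No status is returned. Java SparkR backend might have failed.")
  }
  if (status < 0) {
    stop(SparkR:::readString(conn))   # negative status: an error string follows
  }
  SparkR:::readObject(conn)           # status 0: the result object follows
}
```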
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 2d341d836c..9d82814211 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -19,7 +19,7 @@
# Creates a SparkR client connection object
# if one doesn't already exist
-connectBackend <- function(hostname, port, timeout = 6000) {
+connectBackend <- function(hostname, port, timeout) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index cc6d591bb2..6b4a2f2fdc 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -154,6 +154,7 @@ sparkR.sparkContext <- function(
packages <- processSparkPackages(sparkPackages)
existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")
+ connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
if (existingPort != "") {
if (length(packages) != 0) {
warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell",
@@ -187,6 +188,7 @@ sparkR.sparkContext <- function(
backendPort <- readInt(f)
monitorPort <- readInt(f)
rLibPath <- readString(f)
+ connectionTimeout <- readInt(f)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
@@ -194,7 +196,9 @@ sparkR.sparkContext <- function(
length(rLibPath) != 1) {
stop("JVM failed to launch")
}
- assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv)
+ assign(".monitorConn",
+ socketConnection(port = monitorPort, timeout = connectionTimeout),
+ envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -204,7 +208,7 @@ sparkR.sparkContext <- function(
.sparkREnv$backendPort <- backendPort
tryCatch({
- connectBackend("localhost", backendPort)
+ connectBackend("localhost", backendPort, timeout = connectionTimeout)
},
error = function(err) {
stop("Failed to connect JVM\n")
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index b92e6be995..3a318b71ea 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -18,6 +18,7 @@
# Worker daemon
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
dirs <- strsplit(rLibDir, ",")[[1]]
script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
@@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
-inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600)
+inputCon <- socketConnection(
+ port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
while (TRUE) {
ready <- socketSelect(list(inputCon))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index cfe41ded20..03e7450147 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -90,6 +90,7 @@ bootTime <- currentTimeSecs()
bootElap <- elapsedSecs()
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
dirs <- strsplit(rLibDir, ",")[[1]]
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
@@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]]
suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
-inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
-outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
+inputCon <- socketConnection(
+ port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+outputCon <- socketConnection(
+ port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 41d0a85ee3..550746c552 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import io.netty.handler.timeout.ReadTimeoutHandler
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
@@ -43,7 +44,10 @@ private[spark] class RBackend {
def init(): Int = {
val conf = new SparkConf()
- bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+ bossGroup = new NioEventLoopGroup(
+ conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
@@ -63,6 +67,7 @@ private[spark] class RBackend {
// initialBytesToStrip = 4, i.e. strip out the length field itself
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
+ .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
.addLast("handler", handler)
}
})
@@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging {
val boundPort = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
+ // The connection timeout is set by the socket client. To make it configurable we pass the
+ // timeout value to the client via the temp file.
+ val conf = new SparkConf()
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
// tell the R process via temporary file
val path = args(0)
@@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging {
dos.writeInt(boundPort)
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
+ dos.writeInt(backendConnectionTimeout)
dos.close()
f.renameTo(new File(path))
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 1422ef888f..9f5afa29d6 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -18,16 +18,19 @@
package org.apache.spark.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util.concurrent.TimeUnit
import scala.collection.mutable.HashMap
import scala.language.existentials
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.ChannelHandler.Sharable
+import io.netty.handler.timeout.ReadTimeoutException
import org.apache.spark.api.r.SerDe._
import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Handler for RBackend
@@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend)
writeString(dos, s"Error: unknown method $methodName")
}
} else {
+ // To avoid timeouts when reading results in the SparkR driver, we regularly send
+ // heartbeat responses. The special code +1 signals the client that the backend is
+ // alive and that it should keep blocking for the result.
+ val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
+ val pingRunner = new Runnable {
+ override def run(): Unit = {
+ val pingBaos = new ByteArrayOutputStream()
+ val pingDaos = new DataOutputStream(pingBaos)
+ writeInt(pingDaos, +1)
+ ctx.write(pingBaos.toByteArray)
+ }
+ }
+ val conf = new SparkConf()
+ val heartBeatInterval = conf.getInt(
+ "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+ val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)
+
+ execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ execService.shutdown()
+ execService.awaitTermination(1, TimeUnit.SECONDS)
}
val reply = bos.toByteArray
@@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend)
}
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- // Close the connection when an exception is raised.
- cause.printStackTrace()
- ctx.close()
+ cause match {
+ case timeout: ReadTimeoutException =>
+ // Do nothing. We don't want to timeout on read
+ logWarning("Ignoring read timeout in RBackendHandler")
+ case _ =>
+ // Close the connection when an exception is raised.
+ cause.printStackTrace()
+ ctx.close()
+ }
}
def handleMethodCall(
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 496fdf851f..7ef64723d9 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -333,6 +333,8 @@ private[r] object RRunner {
var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
rCommand = sparkConf.get("spark.r.command", rCommand)
+ val rConnectionTimeout = sparkConf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
val rOptions = "--vanilla"
val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
@@ -344,6 +346,7 @@ private[r] object RRunner {
pb.environment().put("R_TESTS", "")
pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+ pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
val errThread = startStdoutThread(proc)
diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
new file mode 100644
index 0000000000..af67cbbce4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+private[spark] object SparkRDefaults {
+
+ // Default value for spark.r.backendConnectionTimeout config
+ val DEFAULT_CONNECTION_TIMEOUT: Int = 6000
+
+ // Default value for spark.r.heartBeatInterval config
+ val DEFAULT_HEARTBEAT_INTERVAL: Int = 100
+
+ // Default value for spark.r.numRBackendThreads config
+ val DEFAULT_NUM_RBACKEND_THREADS = 2
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index d0466830b2..6eb53a8252 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkUserAppException}
-import org.apache.spark.api.r.{RBackend, RUtils}
+import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults}
import org.apache.spark.util.RedirectThread
/**
@@ -51,6 +51,10 @@ object RRunner {
cmd
}
+ // Connection timeout (in seconds) set by the R process on its connection to RBackend.
+ val backendConnectionTimeout = sys.props.getOrElse(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString)
+
// Check if the file path exists.
// If not, change directory to current working directory for YARN cluster mode
val rF = new File(rFile)
@@ -81,6 +85,7 @@ object RRunner {
val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+ env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout)
val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
// Put the R package directories into an env variable of comma-separated paths
env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
diff --git a/docs/configuration.md b/docs/configuration.md
index 6600cb6c0a..780fc94908 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1890,6 +1890,21 @@ showDF(properties, numRows = 200, truncate = FALSE)
<code>spark.r.shell.command</code> is used for sparkR shell while <code>spark.r.driver.command</code> is used for running R script.
</td>
</tr>
+<tr>
+ <td><code>spark.r.backendConnectionTimeout</code></td>
+ <td>6000</td>
+ <td>
+ Connection timeout (in seconds) set by the R process on its connection to RBackend.
+ </td>
+</tr>
+<tr>
+ <td><code>spark.r.heartBeatInterval</code></td>
+ <td>100</td>
+ <td>
+ Interval (in seconds) for heartbeats sent from the SparkR backend to the R process to prevent connection timeouts.
+ </td>
+</tr>
+
</table>
#### Deploy