Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml                                                          51
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackend.scala            145
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala     223
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRDD.scala                450
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala               340
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/RRunner.scala             92
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala         73
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala  8
8 files changed, 1364 insertions(+), 18 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 6cd1965ec3..e80829b7a7 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -442,4 +442,55 @@
</resources>
</build>
+ <profiles>
+ <profile>
+ <id>Windows</id>
+ <activation>
+ <os>
+ <family>Windows</family>
+ </os>
+ </activation>
+ <properties>
+ <path.separator>\</path.separator>
+ <script.extension>.bat</script.extension>
+ </properties>
+ </profile>
+ <profile>
+ <id>unix</id>
+ <activation>
+ <os>
+ <family>unix</family>
+ </os>
+ </activation>
+ <properties>
+ <path.separator>/</path.separator>
+ <script.extension>.sh</script.extension>
+ </properties>
+ </profile>
+ <profile>
+ <id>sparkr</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <version>1.3.2</version>
+ <executions>
+ <execution>
+ <id>sparkr-pkg</id>
+ <phase>compile</phase>
+ <goals>
+ <goal>exec</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <executable>..${path.separator}R${path.separator}install-dev${script.extension}</executable>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
</project>
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
new file mode 100644
index 0000000000..3a2c94bd9d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetSocketAddress, ServerSocket}
+import java.util.concurrent.TimeUnit
+
+import io.netty.bootstrap.ServerBootstrap
+import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.nio.NioEventLoopGroup
+import io.netty.channel.socket.SocketChannel
+import io.netty.channel.socket.nio.NioServerSocketChannel
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder
+import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+
+import org.apache.spark.Logging
+
+/**
+ * Netty-based backend server that is used to communicate between R and Java.
+ */
+private[spark] class RBackend {
+
+ private[this] var channelFuture: ChannelFuture = null
+ private[this] var bootstrap: ServerBootstrap = null
+ private[this] var bossGroup: EventLoopGroup = null
+
+ def init(): Int = {
+ bossGroup = new NioEventLoopGroup(2)
+ val workerGroup = bossGroup
+ val handler = new RBackendHandler(this)
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(classOf[NioServerSocketChannel])
+
+ bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
+ def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline()
+ .addLast("encoder", new ByteArrayEncoder())
+ .addLast("frameDecoder",
+ // maxFrameLength = 2G
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 4
+ // lengthAdjustment = 0
+ // initialBytesToStrip = 4, i.e. strip out the length field itself
+ new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+ .addLast("decoder", new ByteArrayDecoder())
+ .addLast("handler", handler)
+ }
+ })
+
+ channelFuture = bootstrap.bind(new InetSocketAddress(0))
+ channelFuture.syncUninterruptibly()
+ channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+ }
+
+ def run(): Unit = {
+ channelFuture.channel.closeFuture().syncUninterruptibly()
+ }
+
+ def close(): Unit = {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
+ channelFuture = null
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully()
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully()
+ }
+ bootstrap = null
+ }
+
+}
+
+private[spark] object RBackend extends Logging {
+ def main(args: Array[String]): Unit = {
+ if (args.length < 1) {
+ System.err.println("Usage: RBackend <tempFilePath>")
+ System.exit(-1)
+ }
+ val sparkRBackend = new RBackend()
+ try {
+ // bind to random port
+ val boundPort = sparkRBackend.init()
+ val serverSocket = new ServerSocket(0, 1)
+ val listenPort = serverSocket.getLocalPort()
+
+ // tell the R process via temporary file
+ val path = args(0)
+ val f = new File(path + ".tmp")
+ val dos = new DataOutputStream(new FileOutputStream(f))
+ dos.writeInt(boundPort)
+ dos.writeInt(listenPort)
+ dos.close()
+ f.renameTo(new File(path))
+
+ // wait for the R process to connect back, then exit once its socket closes
+ new Thread("wait for socket to close") {
+ setDaemon(true)
+ override def run(): Unit = {
+ // any uncaught exception will also shut down the JVM
+ val buf = new Array[Byte](1024)
+ // shut down the JVM if R does not connect back in 10 seconds
+ serverSocket.setSoTimeout(10000)
+ try {
+ val inSocket = serverSocket.accept()
+ serverSocket.close()
+ // wait for the socket to close; it is closed when the R process dies
+ inSocket.getInputStream().read(buf)
+ } finally {
+ sparkRBackend.close()
+ System.exit(0)
+ }
+ }
+ }.start()
+
+ sparkRBackend.run()
+ } catch {
+ case e: IOException =>
+ logError("Server shutting down: failed with exception ", e)
+ sparkRBackend.close()
+ System.exit(1)
+ }
+ System.exit(0)
+ }
+}
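
For reference, a minimal sketch (not part of this patch) of the handshake that sits on top of the file RBackend.main writes above. The client below is hypothetical and the path is a made-up example; the real path is whatever is passed as args(0):

    import java.io.{DataInputStream, FileInputStream}
    import java.net.Socket

    // The file contains two big-endian ints written by RBackend.main:
    // the Netty port served by RBackendHandler, and the "watchdog" ServerSocket port.
    val portFile = "/tmp/sparkr-backend-ports"   // hypothetical path
    val in = new DataInputStream(new FileInputStream(portFile))
    val backendPort = in.readInt()               // connect here to invoke JVM methods
    val watchdogPort = in.readInt()              // keep this connection open to keep the JVM alive
    in.close()

    // Closing (or losing) this socket makes the watchdog thread call close() and System.exit(0).
    val keepAlive = new Socket("localhost", watchdogPort)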
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
new file mode 100644
index 0000000000..0075d96371
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.mutable.HashMap
+
+import io.netty.channel.ChannelHandler.Sharable
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.Logging
+import org.apache.spark.api.r.SerDe._
+
+/**
+ * Handler for RBackend.
+ * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use
+ * this across connections?
+ */
+@Sharable
+private[r] class RBackendHandler(server: RBackend)
+ extends SimpleChannelInboundHandler[Array[Byte]] with Logging {
+
+ override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+ val bis = new ByteArrayInputStream(msg)
+ val dis = new DataInputStream(bis)
+
+ val bos = new ByteArrayOutputStream()
+ val dos = new DataOutputStream(bos)
+
+ // The first value read indicates whether this is a static method call (isStatic)
+ val isStatic = readBoolean(dis)
+ val objId = readString(dis)
+ val methodName = readString(dis)
+ val numArgs = readInt(dis)
+
+ if (objId == "SparkRHandler") {
+ methodName match {
+ case "stopBackend" =>
+ writeInt(dos, 0)
+ writeType(dos, "void")
+ server.close()
+ case "rm" =>
+ try {
+ val t = readObjectType(dis)
+ assert(t == 'c')
+ val objToRemove = readString(dis)
+ JVMObjectTracker.remove(objToRemove)
+ writeInt(dos, 0)
+ writeObject(dos, null)
+ } catch {
+ case e: Exception =>
+ logError(s"Removing $objId failed", e)
+ writeInt(dos, -1)
+ }
+ case _ => dos.writeInt(-1)
+ }
+ } else {
+ handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ }
+
+ val reply = bos.toByteArray
+ ctx.write(reply)
+ }
+
+ override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
+ ctx.flush()
+ }
+
+ override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+ // Close the connection when an exception is raised.
+ cause.printStackTrace()
+ ctx.close()
+ }
+
+ def handleMethodCall(
+ isStatic: Boolean,
+ objId: String,
+ methodName: String,
+ numArgs: Int,
+ dis: DataInputStream,
+ dos: DataOutputStream): Unit = {
+ var obj: Object = null
+ try {
+ val cls = if (isStatic) {
+ Class.forName(objId)
+ } else {
+ JVMObjectTracker.get(objId) match {
+ case None => throw new IllegalArgumentException("Object not found " + objId)
+ case Some(o) =>
+ obj = o
+ o.getClass
+ }
+ }
+
+ val args = readArgs(numArgs, dis)
+
+ val methods = cls.getMethods
+ val selectedMethods = methods.filter(m => m.getName == methodName)
+ if (selectedMethods.length > 0) {
+ val methods = selectedMethods.filter { x =>
+ matchMethod(numArgs, args, x.getParameterTypes)
+ }
+ if (methods.isEmpty) {
+ logWarning(s"cannot find matching method ${cls}.$methodName. "
+ + s"Candidates are:")
+ selectedMethods.foreach { method =>
+ logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
+ }
+ throw new Exception(s"No matched method found for $cls.$methodName")
+ }
+ val ret = methods.head.invoke(obj, args:_*)
+
+ // Write status bit
+ writeInt(dos, 0)
+ writeObject(dos, ret.asInstanceOf[AnyRef])
+ } else if (methodName == "<init>") {
+ // methodName should be "<init>" for constructor
+ val ctor = cls.getConstructors.filter { x =>
+ matchMethod(numArgs, args, x.getParameterTypes)
+ }.head
+
+ val obj = ctor.newInstance(args:_*)
+
+ writeInt(dos, 0)
+ writeObject(dos, obj.asInstanceOf[AnyRef])
+ } else {
+ throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId)
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"$methodName on $objId failed", e)
+ writeInt(dos, -1)
+ }
+ }
+
+ // Read a number of arguments from the data input stream
+ def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
+ (0 until numArgs).map { arg =>
+ readObject(dis)
+ }.toArray
+ }
+
+ // Checks whether the arguments passed in args match the parameter types.
+ // NOTE: Currently we do exact match. We may add type conversions later.
+ def matchMethod(
+ numArgs: Int,
+ args: Array[java.lang.Object],
+ parameterTypes: Array[Class[_]]): Boolean = {
+ if (parameterTypes.length != numArgs) {
+ return false
+ }
+
+ for (i <- 0 to numArgs - 1) {
+ val parameterType = parameterTypes(i)
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+ if (!parameterWrapperType.isInstance(args(i))) {
+ return false
+ }
+ }
+ true
+ }
+}
+
+/**
+ * Helper singleton that tracks JVM objects returned to R.
+ * This is useful for referencing these objects in RPC calls.
+ */
+private[r] object JVMObjectTracker {
+
+ // TODO: This map should be thread-safe if we want to support multiple
+ // connections at the same time
+ private[this] val objMap = new HashMap[String, Object]
+
+ // TODO: We support only one connection now, so an integer is fine.
+ // Investigate using an atomic integer in the future.
+ private[this] var objCounter: Int = 0
+
+ def getObject(id: String): Object = {
+ objMap(id)
+ }
+
+ def get(id: String): Option[Object] = {
+ objMap.get(id)
+ }
+
+ def put(obj: Object): String = {
+ val objId = objCounter.toString
+ objCounter = objCounter + 1
+ objMap.put(objId, obj)
+ objId
+ }
+
+ def remove(id: String): Option[Object] = {
+ objMap.remove(id)
+ }
+
+}
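
For reference, a rough sketch (not part of this patch, and ignoring the private[spark] access modifier) of how a request for channelRead0 above could be framed with the SerDe helpers added later in this diff. The call to java.lang.Math.abs is a hypothetical example; on the wire the payload is preceded by a 4-byte length field, which the LengthFieldBasedFrameDecoder configured in RBackend strips before the handler sees it:

    import java.io.{ByteArrayOutputStream, DataOutputStream}
    import org.apache.spark.api.r.SerDe

    // Field order mirrors channelRead0: isStatic, objId, methodName, numArgs, then typed args.
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    SerDe.writeBoolean(dos, true)                // isStatic: call a static method
    SerDe.writeString(dos, "java.lang.Math")     // objId: fully qualified class name for static calls
    SerDe.writeString(dos, "abs")                // methodName
    SerDe.writeInt(dos, 1)                       // numArgs
    SerDe.writeType(dos, "integer")              // each argument is a (type tag, value) pair
    SerDe.writeInt(dos, -1)
    val payload = bos.toByteArray                // prepend a 4-byte length before sending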
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
new file mode 100644
index 0000000000..5fa4d483b8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io._
+import java.net.ServerSocket
+import java.util.{Map => JMap}
+
+import scala.collection.JavaConversions._
+import scala.io.Source
+import scala.reflect.ClassTag
+import scala.util.Try
+
+import org.apache.spark._
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
+ parent: RDD[T],
+ numPartitions: Int,
+ func: Array[Byte],
+ deserializer: String,
+ serializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Broadcast[Object]])
+ extends RDD[U](parent) with Logging {
+ override def getPartitions: Array[Partition] = parent.partitions
+
+ override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+
+ // The parent may also be an RRDD, so we should launch it first.
+ val parentIterator = firstParent[T].iterator(partition, context)
+
+ // we expect two connections
+ val serverSocket = new ServerSocket(0, 2)
+ val listenPort = serverSocket.getLocalPort()
+
+ // The stdout/stderr streams are shared by multiple tasks, because we use one daemon
+ // to launch child processes as workers.
+ val errThread = RRDD.createRWorker(rLibDir, listenPort)
+
+ // We use two sockets to separate input and output, which makes it easier to manage
+ // their lifecycles and avoid deadlock.
+ // TODO: optimize it to use one socket
+
+ // the socket used to send the task input to the R worker
+ serverSocket.setSoTimeout(10000)
+ val inSocket = serverSocket.accept()
+ startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index)
+
+ // the socket used to receive the task output from the R worker
+ val outSocket = serverSocket.accept()
+ val inputStream = new BufferedInputStream(outSocket.getInputStream)
+ val dataStream = openDataStream(inputStream)
+ serverSocket.close()
+
+ try {
+
+ return new Iterator[U] {
+ def next(): U = {
+ val obj = _nextObj
+ if (hasNext) {
+ _nextObj = read()
+ }
+ obj
+ }
+
+ var _nextObj = read()
+
+ def hasNext(): Boolean = {
+ val hasMore = (_nextObj != null)
+ if (!hasMore) {
+ dataStream.close()
+ }
+ hasMore
+ }
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("R computation failed with\n " + errThread.getLines())
+ }
+ }
+
+ /**
+ * Start a thread to write RDD data to the R process.
+ */
+ private def startStdinThread[T](
+ output: OutputStream,
+ iter: Iterator[T],
+ partition: Int): Unit = {
+
+ val env = SparkEnv.get
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val stream = new BufferedOutputStream(output, bufferSize)
+
+ new Thread("writer for R") {
+ override def run(): Unit = {
+ try {
+ SparkEnv.set(env)
+ val dataOut = new DataOutputStream(stream)
+ dataOut.writeInt(partition)
+
+ SerDe.writeString(dataOut, deserializer)
+ SerDe.writeString(dataOut, serializer)
+
+ dataOut.writeInt(packageNames.length)
+ dataOut.write(packageNames)
+
+ dataOut.writeInt(func.length)
+ dataOut.write(func)
+
+ dataOut.writeInt(broadcastVars.length)
+ broadcastVars.foreach { broadcast =>
+ // TODO(shivaram): Read a Long in R to avoid this cast
+ dataOut.writeInt(broadcast.id.toInt)
+ // TODO: Pass a byte array from R to avoid this cast?
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(broadcastByteArr.length)
+ dataOut.write(broadcastByteArr)
+ }
+
+ dataOut.writeInt(numPartitions)
+
+ if (!iter.hasNext) {
+ dataOut.writeInt(0)
+ } else {
+ dataOut.writeInt(1)
+ }
+
+ val printOut = new PrintStream(stream)
+
+ def writeElem(elem: Any): Unit = {
+ if (deserializer == SerializationFormats.BYTE) {
+ val elemArr = elem.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(elemArr.length)
+ dataOut.write(elemArr)
+ } else if (deserializer == SerializationFormats.ROW) {
+ dataOut.write(elem.asInstanceOf[Array[Byte]])
+ } else if (deserializer == SerializationFormats.STRING) {
+ printOut.println(elem)
+ }
+ }
+
+ for (elem <- iter) {
+ elem match {
+ case (key, value) =>
+ writeElem(key)
+ writeElem(value)
+ case _ =>
+ writeElem(elem)
+ }
+ }
+ stream.flush()
+ } catch {
+ // TODO: We should propagate this error to the task thread
+ case e: Exception =>
+ logError("R Writer thread got an exception", e)
+ } finally {
+ Try(output.close())
+ }
+ }
+ }.start()
+ }
+
+ protected def openDataStream(input: InputStream): Closeable
+
+ protected def read(): U
+}
+
+/**
+ * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R.
+ * This is used by SparkR's shuffle operations.
+ */
+private class PairwiseRRDD[T: ClassTag](
+ parent: RDD[T],
+ numPartitions: Int,
+ hashFunc: Array[Byte],
+ deserializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Object])
+ extends BaseRRDD[T, (Int, Array[Byte])](
+ parent, numPartitions, hashFunc, deserializer,
+ SerializationFormats.BYTE, packageNames, rLibDir,
+ broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
+
+ private var dataStream: DataInputStream = _
+
+ override protected def openDataStream(input: InputStream): Closeable = {
+ dataStream = new DataInputStream(input)
+ dataStream
+ }
+
+ override protected def read(): (Int, Array[Byte]) = {
+ try {
+ val length = dataStream.readInt()
+
+ length match {
+ case length if length == 2 =>
+ val hashedKey = dataStream.readInt()
+ val contentPairsLength = dataStream.readInt()
+ val contentPairs = new Array[Byte](contentPairsLength)
+ dataStream.readFully(contentPairs)
+ (hashedKey, contentPairs)
+ case _ => null // End of input
+ }
+ } catch {
+ case eof: EOFException => {
+ throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
+ }
+
+ lazy val asJavaPairRDD: JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+/**
+ * An RDD that stores serialized R objects as Array[Byte].
+ */
+private class RRDD[T: ClassTag](
+ parent: RDD[T],
+ func: Array[Byte],
+ deserializer: String,
+ serializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Object])
+ extends BaseRRDD[T, Array[Byte]](
+ parent, -1, func, deserializer, serializer, packageNames, rLibDir,
+ broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
+
+ private var dataStream: DataInputStream = _
+
+ override protected def openDataStream(input: InputStream): Closeable = {
+ dataStream = new DataInputStream(input)
+ dataStream
+ }
+
+ override protected def read(): Array[Byte] = {
+ try {
+ val length = dataStream.readInt()
+
+ length match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ dataStream.readFully(obj, 0, length)
+ obj
+ case _ => null
+ }
+ } catch {
+ case eof: EOFException => {
+ throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
+ }
+
+ lazy val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+}
+
+/**
+ * An RDD that stores R objects as Array[String].
+ */
+private class StringRRDD[T: ClassTag](
+ parent: RDD[T],
+ func: Array[Byte],
+ deserializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Object])
+ extends BaseRRDD[T, String](
+ parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
+ broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
+
+ private var dataStream: BufferedReader = _
+
+ override protected def openDataStream(input: InputStream): Closeable = {
+ dataStream = new BufferedReader(new InputStreamReader(input))
+ dataStream
+ }
+
+ override protected def read(): String = {
+ try {
+ dataStream.readLine()
+ } catch {
+ case e: IOException => {
+ throw new SparkException("R worker exited unexpectedly (crashed)", e)
+ }
+ }
+ }
+
+ lazy val asJavaRDD: JavaRDD[String] = JavaRDD.fromRDD(this)
+}
+
+private[r] class BufferedStreamThread(
+ in: InputStream,
+ name: String,
+ errBufferSize: Int) extends Thread(name) with Logging {
+ val lines = new Array[String](errBufferSize)
+ var lineIdx = 0
+ override def run() {
+ for (line <- Source.fromInputStream(in).getLines) {
+ synchronized {
+ lines(lineIdx) = line
+ lineIdx = (lineIdx + 1) % errBufferSize
+ }
+ logInfo(line)
+ }
+ }
+
+ def getLines(): String = synchronized {
+ (0 until errBufferSize).filter { x =>
+ lines((x + lineIdx) % errBufferSize) != null
+ }.map { x =>
+ lines((x + lineIdx) % errBufferSize)
+ }.mkString("\n")
+ }
+}
+
+private[r] object RRDD {
+ // Because forking processes from Java is expensive, we prefer to launch
+ // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
+ // This daemon currently only works on UNIX-based systems, so we
+ // also fall back to launching workers (worker.R) directly.
+ private[this] var errThread: BufferedStreamThread = _
+ private[this] var daemonChannel: DataOutputStream = _
+
+ def createSparkContext(
+ master: String,
+ appName: String,
+ sparkHome: String,
+ jars: Array[String],
+ sparkEnvirMap: JMap[Object, Object],
+ sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = {
+
+ val sparkConf = new SparkConf().setAppName(appName)
+ .setSparkHome(sparkHome)
+ .setJars(jars)
+
+ // Override `master` if we have a user-specified value
+ if (master != "") {
+ sparkConf.setMaster(master)
+ } else {
+ // If the conf has no master set, default it to "local" to maintain
+ // backwards compatibility
+ sparkConf.setIfMissing("spark.master", "local")
+ }
+
+ for ((name, value) <- sparkEnvirMap) {
+ sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String])
+ }
+ for ((name, value) <- sparkExecutorEnvMap) {
+ sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String])
+ }
+
+ new JavaSparkContext(sparkConf)
+ }
+
+ /**
+ * Start a thread that buffers and logs the combined stdout/stderr of the R process
+ */
+ private def startStdoutThread(proc: Process): BufferedStreamThread = {
+ val BUFFER_SIZE = 100
+ val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
+ thread.setDaemon(true)
+ thread.start()
+ thread
+ }
+
+ private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
+ val rCommand = "Rscript"
+ val rOptions = "--vanilla"
+ val rExecScript = rLibDir + "/SparkR/worker/" + script
+ val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
+ // Unset the R_TESTS environment variable for workers.
+ // This is set by R CMD check as startup.Rs
+ // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
+ // and confuses the worker script, which then tries to load a non-existent file
+ pb.environment().put("R_TESTS", "")
+ pb.environment().put("SPARKR_RLIBDIR", rLibDir)
+ pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+ pb.redirectErrorStream(true) // redirect stderr into stdout
+ val proc = pb.start()
+ val errThread = startStdoutThread(proc)
+ errThread
+ }
+
+ /**
+ * Launch a worker R process, either by asking the daemon to fork one (on UNIX-based
+ * systems) or by starting worker.R directly.
+ */
+ def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = {
+ val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
+ if (!Utils.isWindows && useDaemon) {
+ synchronized {
+ if (daemonChannel == null) {
+ // we expect one connection
+ val serverSocket = new ServerSocket(0, 1)
+ val daemonPort = serverSocket.getLocalPort
+ errThread = createRProcess(rLibDir, daemonPort, "daemon.R")
+ // the channel used to send worker port numbers to the daemon
+ serverSocket.setSoTimeout(10000)
+ val sock = serverSocket.accept()
+ daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ serverSocket.close()
+ }
+ try {
+ daemonChannel.writeInt(port)
+ daemonChannel.flush()
+ } catch {
+ case e: IOException =>
+ // daemon process died
+ daemonChannel.close()
+ daemonChannel = null
+ errThread = null
+ // fail the current task; the scheduler will retry it
+ throw e
+ }
+ errThread
+ }
+ } else {
+ createRProcess(rLibDir, port, "worker.R")
+ }
+ }
+
+ /**
+ * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is
+ * called from R.
+ */
+ def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = {
+ JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length))
+ }
+
+}
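
For reference, a small sketch (not part of this patch, and ignoring the private[r] access modifier) of the two entry points R drives through the backend, createSparkContext and createRDDFromArray:

    import org.apache.spark.api.java.JavaSparkContext
    import org.apache.spark.api.r.RRDD

    // Build a context the way R's init call would, with empty environment maps.
    val jsc: JavaSparkContext = RRDD.createSparkContext(
      "local[2]", "sparkr-sketch", sys.env.getOrElse("SPARK_HOME", ""),
      Array.empty[String],
      new java.util.HashMap[Object, Object](),
      new java.util.HashMap[Object, Object]())

    // parallelize() in R serializes each element to a byte array and calls this helper;
    // note that it creates one partition per element.
    val serialized: Array[Array[Byte]] = Array(Array[Byte](1, 2), Array[Byte](3, 4))
    val rdd = RRDD.createRDDFromArray(jsc, serialized)
    assert(rdd.count() == 2)
    jsc.stop()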
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
new file mode 100644
index 0000000000..ccb2a371f4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.sql.{Date, Time}
+
+import scala.collection.JavaConversions._
+
+/**
+ * Utility functions to serialize and deserialize objects to / from R
+ */
+private[spark] object SerDe {
+
+ // Type mapping from R to Java
+ //
+ // NULL -> void
+ // integer -> Int
+ // character -> String
+ // logical -> Boolean
+ // double, numeric -> Double
+ // raw -> Array[Byte]
+ // Date -> Date
+ // POSIXlt/POSIXct -> Time
+ //
+ // list[T] -> Array[T], where T is one of above mentioned types
+ // environment -> Map[String, T], where T is a native type
+ // jobj -> Object, where jobj is an object created in the backend
+
+ def readObjectType(dis: DataInputStream): Char = {
+ dis.readByte().toChar
+ }
+
+ def readObject(dis: DataInputStream): Object = {
+ val dataType = readObjectType(dis)
+ readTypedObject(dis, dataType)
+ }
+
+ def readTypedObject(
+ dis: DataInputStream,
+ dataType: Char): Object = {
+ dataType match {
+ case 'n' => null
+ case 'i' => new java.lang.Integer(readInt(dis))
+ case 'd' => new java.lang.Double(readDouble(dis))
+ case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'c' => readString(dis)
+ case 'e' => readMap(dis)
+ case 'r' => readBytes(dis)
+ case 'l' => readList(dis)
+ case 'D' => readDate(dis)
+ case 't' => readTime(dis)
+ case 'j' => JVMObjectTracker.getObject(readString(dis))
+ case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ }
+
+ def readBytes(in: DataInputStream): Array[Byte] = {
+ val len = readInt(in)
+ val out = new Array[Byte](len)
+ in.readFully(out)
+ out
+ }
+
+ def readInt(in: DataInputStream): Int = {
+ in.readInt()
+ }
+
+ def readDouble(in: DataInputStream): Double = {
+ in.readDouble()
+ }
+
+ def readString(in: DataInputStream): String = {
+ val len = in.readInt()
+ val asciiBytes = new Array[Byte](len)
+ in.readFully(asciiBytes)
+ assert(asciiBytes(len - 1) == 0)
+ val str = new String(asciiBytes.dropRight(1).map(_.toChar))
+ str
+ }
+
+ def readBoolean(in: DataInputStream): Boolean = {
+ val intVal = in.readInt()
+ intVal != 0
+ }
+
+ def readDate(in: DataInputStream): Date = {
+ Date.valueOf(readString(in))
+ }
+
+ def readTime(in: DataInputStream): Time = {
+ val t = in.readDouble()
+ new Time((t * 1000L).toLong)
+ }
+
+ def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBytes(in)).toArray
+ }
+
+ def readIntArr(in: DataInputStream): Array[Int] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readInt(in)).toArray
+ }
+
+ def readDoubleArr(in: DataInputStream): Array[Double] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDouble(in)).toArray
+ }
+
+ def readBooleanArr(in: DataInputStream): Array[Boolean] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBoolean(in)).toArray
+ }
+
+ def readStringArr(in: DataInputStream): Array[String] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readString(in)).toArray
+ }
+
+ def readList(dis: DataInputStream): Array[_] = {
+ val arrType = readObjectType(dis)
+ arrType match {
+ case 'i' => readIntArr(dis)
+ case 'c' => readStringArr(dis)
+ case 'd' => readDoubleArr(dis)
+ case 'b' => readBooleanArr(dis)
+ case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
+ case 'r' => readBytesArr(dis)
+ case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+ }
+ }
+
+ def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
+ val len = readInt(in)
+ if (len > 0) {
+ val keysType = readObjectType(in)
+ val keysLen = readInt(in)
+ val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
+
+ val valuesType = readObjectType(in)
+ val valuesLen = readInt(in)
+ val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType))
+ mapAsJavaMap(keys.zip(values).toMap)
+ } else {
+ new java.util.HashMap[Object, Object]()
+ }
+ }
+
+ // Methods to write out data from Java to R
+ //
+ // Type mapping from Java to R
+ //
+ // void -> NULL
+ // Int -> integer
+ // String -> character
+ // Boolean -> logical
+ // Double -> double
+ // Long -> double
+ // Array[Byte] -> raw
+ // Date -> Date
+ // Time -> POSIXct
+ //
+ // Array[T] -> list()
+ // Object -> jobj
+
+ def writeType(dos: DataOutputStream, typeStr: String): Unit = {
+ typeStr match {
+ case "void" => dos.writeByte('n')
+ case "character" => dos.writeByte('c')
+ case "double" => dos.writeByte('d')
+ case "integer" => dos.writeByte('i')
+ case "logical" => dos.writeByte('b')
+ case "date" => dos.writeByte('D')
+ case "time" => dos.writeByte('t')
+ case "raw" => dos.writeByte('r')
+ case "list" => dos.writeByte('l')
+ case "jobj" => dos.writeByte('j')
+ case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
+ }
+ }
+
+ def writeObject(dos: DataOutputStream, value: Object): Unit = {
+ if (value == null) {
+ writeType(dos, "void")
+ } else {
+ value.getClass.getName match {
+ case "java.lang.String" =>
+ writeType(dos, "character")
+ writeString(dos, value.asInstanceOf[String])
+ case "long" | "java.lang.Long" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Long].toDouble)
+ case "double" | "java.lang.Double" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Double])
+ case "int" | "java.lang.Integer" =>
+ writeType(dos, "integer")
+ writeInt(dos, value.asInstanceOf[Int])
+ case "boolean" | "java.lang.Boolean" =>
+ writeType(dos, "logical")
+ writeBoolean(dos, value.asInstanceOf[Boolean])
+ case "java.sql.Date" =>
+ writeType(dos, "date")
+ writeDate(dos, value.asInstanceOf[Date])
+ case "java.sql.Time" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Time])
+ case "[B" =>
+ writeType(dos, "raw")
+ writeBytes(dos, value.asInstanceOf[Array[Byte]])
+ // TODO: Types not handled right now include
+ // byte, char, short, float
+
+ // Handle arrays
+ case "[Ljava.lang.String;" =>
+ writeType(dos, "list")
+ writeStringArr(dos, value.asInstanceOf[Array[String]])
+ case "[I" =>
+ writeType(dos, "list")
+ writeIntArr(dos, value.asInstanceOf[Array[Int]])
+ case "[J" =>
+ writeType(dos, "list")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+ case "[D" =>
+ writeType(dos, "list")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
+ case "[Z" =>
+ writeType(dos, "list")
+ writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+ case "[[B" =>
+ writeType(dos, "list")
+ writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
+ case otherName =>
+ // Handle array of objects
+ if (otherName.startsWith("[L")) {
+ val objArr = value.asInstanceOf[Array[Object]]
+ writeType(dos, "list")
+ writeType(dos, "jobj")
+ dos.writeInt(objArr.length)
+ objArr.foreach(o => writeJObj(dos, o))
+ } else {
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
+ }
+ }
+ }
+ }
+
+ def writeInt(out: DataOutputStream, value: Int): Unit = {
+ out.writeInt(value)
+ }
+
+ def writeDouble(out: DataOutputStream, value: Double): Unit = {
+ out.writeDouble(value)
+ }
+
+ def writeBoolean(out: DataOutputStream, value: Boolean): Unit = {
+ val intValue = if (value) 1 else 0
+ out.writeInt(intValue)
+ }
+
+ def writeDate(out: DataOutputStream, value: Date): Unit = {
+ writeString(out, value.toString)
+ }
+
+ def writeTime(out: DataOutputStream, value: Time): Unit = {
+ out.writeDouble(value.getTime.toDouble / 1000.0)
+ }
+
+
+ // NOTE: Only works for ASCII right now
+ def writeString(out: DataOutputStream, value: String): Unit = {
+ val len = value.length
+ out.writeInt(len + 1) // For the \0
+ out.writeBytes(value)
+ out.writeByte(0)
+ }
+
+ def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
+ out.writeInt(value.length)
+ out.write(value)
+ }
+
+ def writeJObj(out: DataOutputStream, value: Object): Unit = {
+ val objId = JVMObjectTracker.put(value)
+ writeString(out, objId)
+ }
+
+ def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
+ writeType(out, "integer")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeInt(v))
+ }
+
+ def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = {
+ writeType(out, "double")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeDouble(v))
+ }
+
+ def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = {
+ writeType(out, "logical")
+ out.writeInt(value.length)
+ value.foreach(v => writeBoolean(out, v))
+ }
+
+ def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = {
+ writeType(out, "character")
+ out.writeInt(value.length)
+ value.foreach(v => writeString(out, v))
+ }
+
+ def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
+ writeType(out, "raw")
+ out.writeInt(value.length)
+ value.foreach(v => writeBytes(out, v))
+ }
+}
+
+private[r] object SerializationFormats {
+ val BYTE = "byte"
+ val STRING = "string"
+ val ROW = "row"
+}
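
For reference, a minimal round-trip sketch (not part of this patch, and again ignoring the private[spark] modifier) showing the tagged encoding that writeObject and readObject above agree on:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import org.apache.spark.api.r.SerDe

    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    SerDe.writeObject(dos, Int.box(42))        // tagged 'i'
    SerDe.writeObject(dos, "hello")            // tagged 'c', written NUL-terminated
    SerDe.writeObject(dos, Array(1.0, 2.5))    // tagged 'l' (list) with element type 'd'

    val dis = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
    assert(SerDe.readObject(dis) == Int.box(42))
    assert(SerDe.readObject(dis) == "hello")
    assert(SerDe.readObject(dis).asInstanceOf[Array[Double]].toSeq == Seq(1.0, 2.5))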
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
new file mode 100644
index 0000000000..e99779f299
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.util.concurrent.{Semaphore, TimeUnit}
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.api.r.RBackend
+import org.apache.spark.util.RedirectThread
+
+/**
+ * Main class used to launch SparkR applications using spark-submit. It executes R as a
+ * subprocess and then has it connect back to the JVM to access system properties etc.
+ */
+object RRunner {
+ def main(args: Array[String]): Unit = {
+ val rFile = PythonRunner.formatPath(args(0))
+
+ val otherArgs = args.slice(1, args.length)
+
+ // Time to wait for SparkR backend to initialize in seconds
+ val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt
+ val rCommand = "Rscript"
+
+ // Check if the file path exists.
+ // If not, use only the file name: in YARN cluster mode the primary resource has
+ // been distributed to the container's current working directory.
+ val rF = new File(rFile)
+ val rFileNormalized = if (!rF.exists()) {
+ new Path(rFile).getName
+ } else {
+ rFile
+ }
+
+ // Launch a SparkR backend server for the R process to connect to; this will let it see our
+ // Java system properties etc.
+ val sparkRBackend = new RBackend()
+ @volatile var sparkRBackendPort = 0
+ val initialized = new Semaphore(0)
+ val sparkRBackendThread = new Thread("SparkR backend") {
+ override def run() {
+ sparkRBackendPort = sparkRBackend.init()
+ initialized.release()
+ sparkRBackend.run()
+ }
+ }
+
+ sparkRBackendThread.start()
+ // Wait for RBackend initialization to finish
+ if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) {
+ // Launch R
+ val returnCode = try {
+ val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs)
+ val env = builder.environment()
+ env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+ val sparkHome = System.getenv("SPARK_HOME")
+ env.put("R_PROFILE_USER",
+ Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator))
+ builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
+ val process = builder.start()
+
+ new RedirectThread(process.getInputStream, System.out, "redirect R output").start()
+
+ process.waitFor()
+ } finally {
+ sparkRBackend.close()
+ }
+ System.exit(returnCode)
+ } else {
+ System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds")
+ System.exit(-1)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 660307d19e..60bc243ebf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -77,6 +77,7 @@ object SparkSubmit {
// Special primary resource names that represent shells rather than application jars.
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
+ private val SPARKR_SHELL = "sparkr-shell"
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
@@ -284,6 +285,13 @@ object SparkSubmit {
}
}
+ // Require all R files to be local
+ if (args.isR && !isYarnCluster) {
+ if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
+ printErrorAndExit(s"Only local R files are supported: $args.primaryResource")
+ }
+ }
+
// The following modes are not supported or applicable
(clusterManager, deployMode) match {
case (MESOS, CLUSTER) =>
@@ -291,6 +299,9 @@ object SparkSubmit {
case (STANDALONE, CLUSTER) if args.isPython =>
printErrorAndExit("Cluster deploy mode is currently not supported for python " +
"applications on standalone clusters.")
+ case (STANDALONE, CLUSTER) if args.isR =>
+ printErrorAndExit("Cluster deploy mode is currently not supported for R " +
+ "applications on standalone clusters.")
case (_, CLUSTER) if isShell(args.primaryResource) =>
printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.")
case (_, CLUSTER) if isSqlShell(args.mainClass) =>
@@ -317,11 +328,32 @@ object SparkSubmit {
}
}
- // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
- // that can be distributed with the job
- if (args.isPython && isYarnCluster) {
- args.files = mergeFileLists(args.files, args.primaryResource)
- args.files = mergeFileLists(args.files, args.pyFiles)
+ // If we're running an R app, set the main class to our specific R runner
+ if (args.isR && deployMode == CLIENT) {
+ if (args.primaryResource == SPARKR_SHELL) {
+ args.mainClass = "org.apache.spark.api.r.RBackend"
+ } else {
+ // If an R file is provided, add it to the child arguments and list of files to deploy.
+ // Usage: RRunner <main R file> [app arguments]
+ args.mainClass = "org.apache.spark.deploy.RRunner"
+ args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ }
+ }
+
+ if (isYarnCluster) {
+ // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
+ // that can be distributed with the job
+ if (args.isPython) {
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ args.files = mergeFileLists(args.files, args.pyFiles)
+ }
+
+ // In yarn-cluster mode for an R app, add the primary resource to the files
+ // that can be distributed with the job
+ if (args.isR) {
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ }
}
// Special flag to avoid deprecation warnings at the client
@@ -405,8 +437,8 @@ object SparkSubmit {
// Add the application jar automatically so the user doesn't have to call sc.addJar
// For YARN cluster mode, the jar is already distributed on each node as "app.jar"
- // For python files, the primary resource is already distributed as a regular file
- if (!isYarnCluster && !args.isPython) {
+ // For python and R files, the primary resource is already distributed as a regular file
+ if (!isYarnCluster && !args.isPython && !args.isR) {
var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty)
if (isUserJar(args.primaryResource)) {
jars = jars ++ Seq(args.primaryResource)
@@ -447,6 +479,10 @@ object SparkSubmit {
childArgs += ("--py-files", pyFilesNames)
}
childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
+ } else if (args.isR) {
+ val mainFile = new Path(args.primaryResource).getName
+ childArgs += ("--primary-r-file", mainFile)
+ childArgs += ("--class", "org.apache.spark.deploy.RRunner")
} else {
if (args.primaryResource != SPARK_INTERNAL) {
childArgs += ("--jar", args.primaryResource)
@@ -591,15 +627,15 @@ object SparkSubmit {
/**
* Return whether the given primary resource represents a user jar.
*/
- private def isUserJar(primaryResource: String): Boolean = {
- !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource)
+ private[deploy] def isUserJar(res: String): Boolean = {
+ !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res)
}
/**
* Return whether the given primary resource represents a shell.
*/
- private[deploy] def isShell(primaryResource: String): Boolean = {
- primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL
+ private[deploy] def isShell(res: String): Boolean = {
+ (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL)
}
/**
@@ -619,12 +655,19 @@ object SparkSubmit {
/**
* Return whether the given primary resource requires running python.
*/
- private[deploy] def isPython(primaryResource: String): Boolean = {
- primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL
+ private[deploy] def isPython(res: String): Boolean = {
+ res != null && res.endsWith(".py") || res == PYSPARK_SHELL
+ }
+
+ /**
+ * Return whether the given primary resource requires running R.
+ */
+ private[deploy] def isR(res: String): Boolean = {
+ res != null && res.endsWith(".R") || res == SPARKR_SHELL
}
- private[deploy] def isInternal(primaryResource: String): Boolean = {
- primaryResource == SPARK_INTERNAL
+ private[deploy] def isInternal(res: String): Boolean = {
+ res == SPARK_INTERNAL
}
/**
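
For reference, a self-contained restatement (not part of this patch) of the resource classification the new isR helper above introduces; the real method is private[deploy] on SparkSubmit:

    // Mirrors SparkSubmit.isR after this change.
    val SPARKR_SHELL = "sparkr-shell"
    def isR(res: String): Boolean = res != null && res.endsWith(".R") || res == SPARKR_SHELL

    assert(isR("analysis.R"))      // a plain R script: SparkSubmit routes it through RRunner
    assert(isR(SPARKR_SHELL))      // the interactive shell: RBackend is launched directly
    assert(!isR("app.jar"))        // ordinary JARs keep the existing submission path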
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 6eb73c4347..03ecf3fd99 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
+ var isR: Boolean = false
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
var proxyUser: String = null
@@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
// Try to set main class from JAR if no --class argument is given
- if (mainClass == null && !isPython && primaryResource != null) {
+ if (mainClass == null && !isPython && !isR && primaryResource != null) {
val uri = new URI(primaryResource)
val uriScheme = uri.getScheme()
@@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
printUsageAndExit(-1)
}
if (primaryResource == null) {
- SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)")
+ SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)")
}
- if (mainClass == null && !isPython) {
+ if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) {
SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class")
}
if (pyFiles != null && !isPython) {
@@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
opt
}
isPython = SparkSubmit.isPython(opt)
+ isR = SparkSubmit.isR(opt)
false
}