author     Reynold Xin <rxin@cs.berkeley.edu>  2012-04-17 16:40:56 -0700
committer  Reynold Xin <rxin@cs.berkeley.edu>  2012-04-17 16:40:56 -0700
commit     e601b3b9e56d6ce978c09506ca07fd3e252e4673 (patch)
tree       a63b16330c93ef411e47f407be31b6474628d40d /core
parent     3b745176e0ec5fda8c7afef04aec1040e1c649a9 (diff)
Added the ability to set environment variables in PipedRDD.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/PipedRDD.scala       | 14
-rw-r--r--  core/src/main/scala/spark/RDD.scala            |  5
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala  | 37
3 files changed, 52 insertions(+), 4 deletions(-)
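
In effect, this commit adds an overload of RDD.pipe that forwards an environment map to the subprocess it launches. A minimal usage sketch, mirroring the new test below (local-mode setup; variable names are illustrative, not part of this commit):

    val sc = new SparkContext("local", "example")
    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // New in this commit: pass extra env vars to the piped process.
    val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
    println(piped.collect().toSeq)  // one "LALALA" per partition
    sc.stop()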
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 3f993d895a..8a5de3d7e9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
+import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
@@ -10,8 +11,12 @@ import scala.io.Source
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
-class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
+class PipedRDD[T: ClassManifest](
+ parent: RDD[T], command: Seq[String], envVars: Map[String, String])
extends RDD[String](parent.context) {
+
+ def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
@@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
override val dependencies = List(new OneToOneDependency(parent))
override def compute(split: Split): Iterator[String] = {
- val proc = Runtime.getRuntime.exec(command.toArray)
+ val pb = new ProcessBuilder(command)
+ // Add the environment variables to the process.
+ val currentEnvVars = pb.environment()
+ envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) }
+
+ val proc = pb.start()
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
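
The enabling change above is the switch from Runtime.exec to ProcessBuilder: ProcessBuilder exposes a mutable view of the child's environment, and the new JavaConversions import lets the Scala Seq of command tokens be passed where ProcessBuilder expects a java.util.List. A self-contained sketch of the same pattern outside Spark (command and variable names are illustrative; assumes a Unix-like printenv on the PATH):

    // Standalone illustration of the ProcessBuilder pattern used in compute().
    val pb = new ProcessBuilder("printenv", "MY_VAR")
    pb.environment().put("MY_VAR", "some-value")  // child inherits parent env plus this entry
    val proc = pb.start()
    val output = scala.io.Source.fromInputStream(proc.getInputStream).mkString
    println(output.trim)  // expected: some-value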
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1160de5fd1..7fe6633f1b 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -9,8 +9,6 @@ import java.util.Random
import java.util.Date
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
-import scala.collection.mutable.HashMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
+ def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
+ new PipedRDD(this, command, env)
+
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))
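
Note that with the mutable Map imports removed, the env parameter of the new pipe overload is the default immutable scala.Predef.Map. The old call sites remain source-compatible because the secondary constructor delegates with an empty map; equivalent calls, for illustration (rdd is a placeholder RDD):

    rdd.pipe(Seq("cat"))                              // old path: PipedRDD(rdd, command, Map())
    rdd.pipe(Seq("cat"), Map.empty[String, String])   // new path with an explicitly empty env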
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
new file mode 100644
index 0000000000..d5dc2efd91
--- /dev/null
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -0,0 +1,37 @@
+package spark
+
+import org.scalatest.FunSuite
+import SparkContext._
+
+class PipedRDDSuite extends FunSuite {
+
+ test("basic pipe") {
+ val sc = new SparkContext("local", "test")
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+ val piped = nums.pipe(Seq("cat"))
+
+ val c = piped.collect()
+ println(c.toSeq)
+ assert(c.size === 4)
+ assert(c(0) === "1")
+ assert(c(1) === "2")
+ assert(c(2) === "3")
+ assert(c(3) === "4")
+ sc.stop()
+ }
+
+ test("pipe with env variable") {
+ val sc = new SparkContext("local", "test")
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+ val c = piped.collect()
+ assert(c.size === 2)
+ assert(c(0) === "LALALA")
+ assert(c(1) === "LALALA")
+ sc.stop()
+ }
+
+}
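
A note on the expected sizes in the env test: pipe() launches one subprocess per partition, so with makeRDD(..., 2) printenv runs twice and collect() returns two lines. A hypothetical variation of that test making the relationship explicit (assumes a fresh SparkContext sc):

    // Illustrative only: with 4 partitions, the command runs 4 times.
    val nums4 = sc.makeRDD(1 to 4, 4)
    val out4 = nums4.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
    assert(out4.collect().size == 4)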
+
+