Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala    19
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala      8
-rw-r--r--  core/src/test/scala/org/apache/spark/PipedRDDSuite.scala   184
3 files changed, 158 insertions, 53 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a374fc4a87..100ddb3607 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -18,8 +18,10 @@
package org.apache.spark.rdd
import java.io.EOFException
+import scala.collection.immutable.Map
import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
@@ -43,6 +45,23 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
override def hashCode(): Int = 41 * (41 + rddId) + idx
override val index: Int = idx
+
+ /**
+ * Get any environment variables that should be added to the user's environment when running pipes
+ * @return a Map with the environment variables and their corresponding values; it may be empty
+ */
+ def getPipeEnvVars(): Map[String, String] = {
+ val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
+ val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
+ // map_input_file is deprecated in favor of mapreduce_map_input_file, but set both
+ // since it's not removed yet
+ Map("map_input_file" -> is.getPath().toString(),
+ "mapreduce_map_input_file" -> is.getPath().toString())
+ } else {
+ Map()
+ }
+ envVars
+ }
}
/**
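
For context, here is a minimal standalone sketch (not part of the patch) of what getPipeEnvVars() produces for a FileSplit-backed partition; the path and host name are illustrative:

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapred.FileSplit

    // Build a fake split the same way the new test suite does below.
    val split = new FileSplit(new Path("/some/path"), 0, 1, Array[String]("host1"))

    // Both the deprecated Hadoop name and its replacement point at the split's path.
    val envVars = Map(
      "map_input_file" -> split.getPath().toString(),
      "mapreduce_map_input_file" -> split.getPath().toString())
    // envVars("map_input_file") == "/some/path"
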
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index abd4414e81..4250a9d02f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -28,6 +28,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, SparkEnv, TaskContext}
+
/**
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
@@ -59,6 +60,13 @@ class PipedRDD[T: ClassTag](
val currentEnvVars = pb.environment()
envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
+ // for compatibility with Hadoop, which sets these env variables
+ // so that user code can access the input filename
+ if (split.isInstanceOf[HadoopPartition]) {
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
+ currentEnvVars.putAll(hadoopSplit.getPipeEnvVars())
+ }
+
val proc = pb.start()
val env = SparkEnv.get
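
The hunk above relies on ProcessBuilder exposing a mutable environment map that can be extended before start(). A minimal sketch of that mechanism, using an illustrative variable name and value:

    import scala.collection.JavaConverters._

    // ProcessBuilder.environment() is a live java.util.Map; entries put into it
    // become environment variables of the child process.
    val pb = new ProcessBuilder(List("printenv", "map_input_file").asJava)
    pb.environment().putAll(Map("map_input_file" -> "/some/path").asJava)
    val proc = pb.start()
    // The piped command now sees map_input_file=/some/path.
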
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
index 3a0385a1b0..0bac78d8a6 100644
--- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -19,74 +19,152 @@ package org.apache.spark
import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
+import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
+import org.apache.hadoop.fs.Path
+
+import scala.collection.Map
+import scala.sys.process._
+import scala.util.Try
+import org.apache.hadoop.io.{Text, LongWritable}
+
class PipedRDDSuite extends FunSuite with SharedSparkContext {
test("basic pipe") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("cat"))
+ val piped = nums.pipe(Seq("cat"))
- val c = piped.collect()
- assert(c.size === 4)
- assert(c(0) === "1")
- assert(c(1) === "2")
- assert(c(2) === "3")
- assert(c(3) === "4")
+ val c = piped.collect()
+ assert(c.size === 4)
+ assert(c(0) === "1")
+ assert(c(1) === "2")
+ assert(c(2) === "3")
+ assert(c(3) === "4")
+ } else {
+ assert(true)
+ }
}
test("advanced pipe") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val bl = sc.broadcast(List("0"))
-
- val piped = nums.pipe(Seq("cat"),
- Map[String, String](),
- (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
- (i:Int, f: String=> Unit) => f(i + "_"))
-
- val c = piped.collect()
-
- assert(c.size === 8)
- assert(c(0) === "0")
- assert(c(1) === "\u0001")
- assert(c(2) === "1_")
- assert(c(3) === "2_")
- assert(c(4) === "0")
- assert(c(5) === "\u0001")
- assert(c(6) === "3_")
- assert(c(7) === "4_")
-
- val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
- val d = nums1.groupBy(str=>str.split("\t")(0)).
- pipe(Seq("cat"),
- Map[String, String](),
- (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
- (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
- assert(d.size === 8)
- assert(d(0) === "0")
- assert(d(1) === "\u0001")
- assert(d(2) === "b\t2_")
- assert(d(3) === "b\t4_")
- assert(d(4) === "0")
- assert(d(5) === "\u0001")
- assert(d(6) === "a\t1_")
- assert(d(7) === "a\t3_")
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val bl = sc.broadcast(List("0"))
+
+ val piped = nums.pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {
+ bl.value.map(f(_)); f("\u0001")
+ },
+ (i: Int, f: String => Unit) => f(i + "_"))
+
+ val c = piped.collect()
+
+ assert(c.size === 8)
+ assert(c(0) === "0")
+ assert(c(1) === "\u0001")
+ assert(c(2) === "1_")
+ assert(c(3) === "2_")
+ assert(c(4) === "0")
+ assert(c(5) === "\u0001")
+ assert(c(6) === "3_")
+ assert(c(7) === "4_")
+
+ val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+ val d = nums1.groupBy(str => str.split("\t")(0)).
+ pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {
+ bl.value.map(f(_)); f("\u0001")
+ },
+ (i: Tuple2[String, Seq[String]], f: String => Unit) => {
+ for (e <- i._2) {
+ f(e + "_")
+ }
+ }).collect()
+ assert(d.size === 8)
+ assert(d(0) === "0")
+ assert(d(1) === "\u0001")
+ assert(d(2) === "b\t2_")
+ assert(d(3) === "b\t4_")
+ assert(d(4) === "0")
+ assert(d(5) === "\u0001")
+ assert(d(6) === "a\t1_")
+ assert(d(7) === "a\t3_")
+ } else {
+ assert(true)
+ }
}
test("pipe with env variable") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
- val c = piped.collect()
- assert(c.size === 2)
- assert(c(0) === "LALALA")
- assert(c(1) === "LALALA")
+ if (testCommandAvailable("printenv")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+ val c = piped.collect()
+ assert(c.size === 2)
+ assert(c(0) === "LALALA")
+ assert(c(1) === "LALALA")
+ } else {
+ assert(true)
+ }
}
test("pipe with non-zero exit status") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
- intercept[SparkException] {
- piped.collect()
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
+ intercept[SparkException] {
+ piped.collect()
+ }
+ } else {
+ assert(true)
}
}
+ test("test pipe exports map_input_file") {
+ testExportInputFile("map_input_file")
+ }
+
+ test("test pipe exports mapreduce_map_input_file") {
+ testExportInputFile("mapreduce_map_input_file")
+ }
+
+ def testCommandAvailable(command: String): Boolean = {
+ Try(Process(command) !!).isSuccess
+ }
+
+ def testExportInputFile(varName: String) {
+ if (testCommandAvailable("printenv")) {
+ val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
+ classOf[Text], 2) {
+ override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())
+
+ override val getDependencies = List[Dependency[_]]()
+
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
+ new Text("b"))))
+ }
+ }
+ val hadoopPart1 = generateFakeHadoopPartition()
+ val pipedRdd = new PipedRDD(nums, "printenv " + varName)
+ val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
+ val rddIter = pipedRdd.compute(hadoopPart1, tContext)
+ val arr = rddIter.toArray
+ assert(arr(0) == "/some/path")
+ } else {
+ // printenv isn't available so just pass the test
+ assert(true)
+ }
+ }
+
+ def generateFakeHadoopPartition(): HadoopPartition = {
+ val split = new FileSplit(new Path("/some/path"), 0, 1,
+ Array[String]("loc1", "loc2", "loc3", "loc4", "loc5"))
+ new HadoopPartition(sc.newRddId(), 1, split)
+ }
+
}
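
With the patch applied, the exported variables can be observed end to end from a piped command. A hedged sketch, assuming a local SparkContext `sc`, a POSIX printenv, and an illustrative input path:

    // The piped command runs once per partition; printenv ignores its stdin and
    // prints the split's path that was injected into its environment.
    val lines = sc.textFile("/tmp/example.txt")
    val paths = lines.pipe("printenv map_input_file")
    // Expect one distinct value resembling the input path, e.g. file:/tmp/example.txt.
    paths.distinct().collect()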