Diffstat (limited to 'core/src/main/scala/spark/rdd/NewHadoopRDD.scala')
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 34 ++++++++++++++++------------------
1 file changed, 16 insertions(+), 18 deletions(-)
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 7a1a0fb87d..197ed5ea17 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -1,22 +1,19 @@
 package spark.rdd
 
+import java.text.SimpleDateFormat
+import java.util.Date
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
-import java.util.Date
-import java.text.SimpleDateFormat
 
-import spark.Dependency
-import spark.RDD
-import spark.SerializableWritable
-import spark.SparkContext
-import spark.Split
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
 
 
-private[spark] 
+private[spark]
 class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
   extends Split {
-  
+
   val serializableHadoopSplit = new SerializableWritable(rawSplit)
 
   override def hashCode(): Int = (41 * (41 + rddId) + index)
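
The hashCode kept as context above is the usual prime-multiplier combine over the two
fields that identify a split. A tiny illustration of what it computes (the helper name
is hypothetical), including the fact that it is only a hash, not a unique key:

    // Prime-multiplier combine, as in NewHadoopSplit.hashCode
    def splitHash(rddId: Int, index: Int): Int = 41 * (41 + rddId) + index

    splitHash(1, 0)   // 41 * 42 + 0  = 1722
    splitHash(0, 41)  // 41 * 41 + 41 = 1722 -- collisions remain possible
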
@@ -29,7 +26,7 @@ class NewHadoopRDD[K, V](
     @transient conf: Configuration)
   extends RDD[(K, V)](sc)
   with HadoopMapReduceUtil {
-  
+
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   val confBroadcast = sc.broadcast(new SerializableWritable(conf))
   // private val serializableConf = new SerializableWritable(conf)
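
The comment kept in this hunk records the design choice: a Hadoop Configuration is
roughly 10 KB, so shipping it once per executor as a broadcast variable beats
serializing a copy into every task closure. Configuration implements Hadoop's
Writable but not java.io.Serializable, which is why it is wrapped first. A minimal
sketch of such a wrapper, simplified from Spark's SerializableWritable (details
assumed here):

    import java.io.{ObjectInputStream, ObjectOutputStream}
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{ObjectWritable, Writable}

    class SerializableWritableSketch[T <: Writable](@transient var t: T) extends Serializable {
      def value: T = t

      // Delegate to the Writable's own wire format during Java serialization.
      private def writeObject(out: ObjectOutputStream) {
        out.defaultWriteObject()
        new ObjectWritable(t).write(out)
      }

      private def readObject(in: ObjectInputStream) {
        in.defaultReadObject()
        val ow = new ObjectWritable()
        ow.setConf(new Configuration())
        ow.readFields(in)
        t = ow.get().asInstanceOf[T]
      }
    }

This double wrapping is why the worker-side code below reads confBroadcast.value.value:
the first value unwraps the broadcast variable, the second unwraps the Writable.
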
@@ -56,15 +53,19 @@ class NewHadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[NewHadoopSplit]
     val conf = confBroadcast.value.value
     val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
-    val context = newTaskAttemptContext(conf, attemptId)
+    val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
     val format = inputFormatClass.newInstance
-    val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
-    reader.initialize(split.serializableHadoopSplit.value, context)
-
+    val reader = format.createRecordReader(
+      split.serializableHadoopSplit.value, hadoopAttemptContext)
+    reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+    // Register an on-task-completion callback to close the input stream.
+    context.addOnCompleteCallback(() => reader.close())
+
     var havePair = false
     var finished = false
 
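
Since compute now receives the new Spark TaskContext, the Hadoop TaskAttemptContext is
renamed hadoopAttemptContext to free up the name context. Registering reader.close() as
an on-completion callback means the record reader is closed even when a task stops
consuming the iterator early (say, after a take() or on failure), not only when the
iterator is drained. A hedged sketch of the callback mechanism this assumes, simplified
from the TaskContext of this era:

    import scala.collection.mutable.ArrayBuffer

    class TaskContextSketch(val stageId: Int, val splitId: Int, val attemptId: Long) {
      private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

      // Called by RDD code to register cleanup work for this task.
      def addOnCompleteCallback(f: () => Unit) {
        onCompleteCallbacks += f
      }

      // Called by the executor after the task body finishes, success or failure.
      def executeOnCompleteCallbacks() {
        onCompleteCallbacks.foreach(_())
      }
    }
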
@@ -72,9 +73,6 @@ class NewHadoopRDD[K, V](
       if (!finished && !havePair) {
         finished = !reader.nextKeyValue
         havePair = !finished
-        if (finished) {
-          reader.close()
-        }
       }
       !finished
     }
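
With close() moved to the completion callback, hasNext is left with only the one-record
lookahead: it must call reader.nextKeyValue to learn whether another record exists, so
the answer is cached in havePair until next() consumes it. The pattern in isolation,
over a hypothetical fetch() standing in for the RecordReader calls:

    // Sketch of the lookahead-iterator pattern used by compute(); fetch() and
    // current are stand-ins for reader.nextKeyValue and getCurrentKey/Value.
    abstract class LookaheadIterator[A] extends Iterator[A] {
      protected def fetch(): Boolean  // advance the reader; false at end of input
      protected def current: A        // the record produced by the last fetch()

      private var havePair = false
      private var finished = false

      override def hasNext: Boolean = {
        if (!finished && !havePair) {
          finished = !fetch()
          havePair = !finished
        }
        !finished
      }

      override def next(): A = {
        if (!hasNext) throw new java.util.NoSuchElementException("End of stream")
        havePair = false   // hand the cached record out exactly once
        current
      }
    }
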