diff options
-rw-r--r-- | core/src/main/scala/spark/NewHadoopRDD.scala | 88 | ||||
-rw-r--r-- | core/src/main/scala/spark/PairRDDFunctions.scala | 55 | ||||
-rw-r--r-- | core/src/main/scala/spark/SparkContext.scala | 29 | ||||
-rw-r--r-- | core/src/test/scala/spark/FileSuite.scala | 26 |
4 files changed, 198 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/NewHadoopRDD.scala
new file mode 100644
index 0000000000..c40a39cbe0
--- /dev/null
+++ b/core/src/main/scala/spark/NewHadoopRDD.scala
@@ -0,0 +1,88 @@
+package spark
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce.InputFormat
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.JobID
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.TaskAttemptID
+
+import java.util.Date
+import java.text.SimpleDateFormat
+
+/**
+ * A Split backed by one new-API (`org.apache.hadoop.mapreduce`) Hadoop InputSplit.
+ *
+ * The raw split is marked `@transient`; it is reachable only through
+ * `serializableHadoopSplit`, which wraps it in a SerializableWritable.
+ *
+ * @param rddId    id of the owning RDD, mixed into the hash code
+ * @param index    position of this split within the RDD
+ * @param rawSplit the underlying Hadoop split
+ */
+class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+extends Split {
+  val serializableHadoopSplit = new SerializableWritable(rawSplit)
+
+  // Both operands are Ints, so no conversion is needed on the result.
+  override def hashCode(): Int = 41 * (41 + rddId) + index
+}
+
+/**
+ * An RDD whose records are read through a new-API (`org.apache.hadoop.mapreduce`)
+ * Hadoop InputFormat.
+ *
+ * @param sc               the SparkContext this RDD belongs to
+ * @param inputFormatClass InputFormat instantiated to list the splits and to read each one
+ * @param keyClass         class of the keys (not referenced inside this class)
+ * @param valueClass       class of the values (not referenced inside this class)
+ * @param conf             Hadoop configuration; kept only through a SerializableWritable
+ */
+class NewHadoopRDD[K, V](
+    sc: SparkContext,
+    inputFormatClass: Class[_ <: InputFormat[K, V]],
+    keyClass: Class[K], valueClass: Class[V],
+    @transient conf: Configuration)
+extends RDD[(K, V)](sc) {
+  private val serializableConf = new SerializableWritable(conf)
+
+  // Creation time at minute resolution; used as the identifier string of the
+  // JobID and TaskAttemptIDs built below.
+  private val jobtrackerId: String = {
+    val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+    formatter.format(new Date())
+  }
+
+  @transient private val jobId = new JobID(jobtrackerId, id)
+
+  // Computed once at construction: one NewHadoopSplit per split reported by the InputFormat.
+  @transient private val splits_ : Array[Split] = {
+    val inputFormat = inputFormatClass.newInstance
+    val jobContext = new JobContext(serializableConf.value, jobId)
+    val rawSplits = inputFormat.getSplits(jobContext).toArray
+    val result = new Array[Split](rawSplits.size)
+    for (i <- 0 until rawSplits.size)
+      result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+    result
+  }
+
+  override def splits = splits_
+
+  /**
+   * Reads one split through a freshly created and initialized RecordReader.
+   *
+   * The returned iterator looks one record ahead: `hasNext` advances the reader and
+   * `next` hands out the reader's current key/value pair. The reader is closed only
+   * once it reports end of input, so an iterator abandoned early leaves it open.
+   */
+  override def compute(theSplit: Split) = new Iterator[(K, V)] {
+    val split = theSplit.asInstanceOf[NewHadoopSplit]
+    val conf = serializableConf.value
+    val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
+    val context = new TaskAttemptContext(conf, attemptId)
+    val format = inputFormatClass.newInstance
+    val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
+    reader.initialize(split.serializableHadoopSplit.value, context)
+
+    // havePair: a record has been fetched by hasNext but not yet returned by next.
+    var havePair = false
+    // finished: the reader reported end of input and has been closed.
+    var finished = false
+
+    override def hasNext: Boolean = {
+      if (!finished && !havePair) {
+        finished = !reader.nextKeyValue()
+        havePair = !finished
+        if (finished) {
+          reader.close()
+        }
+      }
+      !finished
+    }
+
+    override def next: (K, V) = {
+      if (!hasNext) {
+        throw new java.util.NoSuchElementException("End of stream")
+      }
+      havePair = false
+      (reader.getCurrentKey, reader.getCurrentValue)
+    }
+  }
+
+  /** Hosts reported by the Hadoop split, with the literal "localhost" entries dropped. */
+  override def preferredLocations(split: Split) = {
+    val theSplit = split.asInstanceOf[NewHadoopSplit]
+    theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+  }
+
+  override val dependencies: List[Dependency[_]] = Nil
+}
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 6a743bf096..074d5abb38 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -7,11 +7,13 @@ import java.util.concurrent.atomic.AtomicLong
 import java.util.HashSet
 import java.util.Random
 import java.util.Date
+import java.text.SimpleDateFormat
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.Map
 import scala.collection.mutable.HashMap
 
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
 import org.apache.hadoop.io.Text
@@ -25,6 +27,13 @@ import org.apache.hadoop.mapred.OutputFormat
 import org.apache.hadoop.mapred.SequenceFileOutputFormat
 import org.apache.hadoop.mapred.TextOutputFormat
 
+import
org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
+import org.apache.hadoop.mapreduce.TaskAttemptID
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
 import SparkContext._
 
 /**
@@ -239,6 +248,52 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) ex
   def saveAsHadoopFile [F <: OutputFormat[K, V]] (path: String) (implicit fm: ClassManifest[F]) {
     saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
   }
+
+  /**
+   * Output the RDD to `path` using a new-API (`org.apache.hadoop.mapreduce`) OutputFormat,
+   * taking the OutputFormat class from the implicit manifest of `F`.
+   */
+  def saveAsNewAPIHadoopFile [F <: NewOutputFormat[K, V]] (path: String) (implicit fm: ClassManifest[F]) {
+    saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+  }
+
+  /**
+   * Output the RDD to `path` using a new-API (`org.apache.hadoop.mapreduce`) OutputFormat.
+   *
+   * `setupJob` and `cleanupJob` are called in the calling thread around `runJob`; inside
+   * the job every split is written and committed as its own Hadoop task attempt.
+   *
+   * @param path              output location set on the job through FileOutputFormat
+   * @param keyClass          key class registered on the Hadoop job
+   * @param valueClass        value class registered on the Hadoop job
+   * @param outputFormatClass OutputFormat class; instantiated once here and once per split
+   */
+  def saveAsNewAPIHadoopFile(path: String,
+      keyClass: Class[_],
+      valueClass: Class[_],
+      outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
+    val job = new NewAPIHadoopJob
+    job.setOutputKeyClass(keyClass)
+    job.setOutputValueClass(valueClass)
+    val wrappedConf = new SerializableWritable(job.getConfiguration)
+    NewFileOutputFormat.setOutputPath(job, new Path(path))
+    val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+    val jobtrackerID = formatter.format(new Date())
+    val stageId = self.id
+    // Writes the records of one split and commits that task; yields 1 per written split.
+    def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = {
+      /* "reduce task" <split #> <attempt # = spark task #> */
+      val attemptId = new TaskAttemptID(jobtrackerID,
+        stageId, false, context.splitId, context.attemptId)
+      val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId)
+      val format = outputFormatClass.newInstance
+      val committer = format.getOutputCommitter(hadoopContext)
+      committer.setupTask(hadoopContext)
+      val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+      try {
+        while (iter.hasNext) {
+          val (k, v) = iter.next
+          writer.write(k, v)
+        }
+      } finally {
+        // Close the writer even when a write fails so it is not leaked; the task is
+        // committed below only if the whole split was written.
+        writer.close(hadoopContext)
+      }
+      committer.commitTask(hadoopContext)
+      1
+    }
+    val jobFormat = outputFormatClass.newInstance
+    /* apparently we need a TaskAttemptID to construct an OutputCommitter;
+     * however we're only going to use this local OutputCommitter for
+     * setupJob/cleanupJob, so we just use a dummy "map" task.
+     */
+    val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+    val jobTaskContext = new TaskAttemptContext(wrappedConf.value, jobAttemptId)
+    val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+    jobCommitter.setupJob(jobTaskContext)
+    // The per-split results are not needed; the job is run only for its side effects.
+    self.context.runJob(self, writeShard _)
+    jobCommitter.cleanupJob(jobTaskContext)
+  }
 
   def saveAsHadoopFile(path: String,
       keyClass: Class[_],
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index f044da1e21..25b879ba96 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.actors.remote.RemoteActor
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.SequenceFileInputFormat
 import org.apache.hadoop.io.Writable
@@ -22,6 +24,10 @@ import org.apache.hadoop.mapred.FileInputFormat
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapred.TextInputFormat
 
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+
 import spark.broadcast._
 
 class SparkContext(
@@ -123,6 +129,29 @@ extends Logging {
     (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
     hadoopFile[K, V, F](path, defaultMinSplits)
 
+  /** Get an RDD
for a Hadoop file with an arbitrary new API InputFormat. */
+  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
+      (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+    // The explicit-class overload registers `path` itself, so hand it an empty Configuration
+    // instead of pre-registering the path here (which would list it twice).
+    newAPIHadoopFile(path,
+        fm.erasure.asInstanceOf[Class[F]],
+        km.erasure.asInstanceOf[Class[K]],
+        vm.erasure.asInstanceOf[Class[V]],
+        new Configuration)
+  }
+
+  /** Get an RDD for a given Hadoop file with an arbitrary new API InputFormat and extra
+   * configuration options to pass to the input format.
+   *
+   * `path` is added to the input paths of a Hadoop job built from `conf`, unless an equal
+   * (file-system qualified) path is already listed there. The RDD is created from that
+   * job's configuration rather than from `conf` directly.
+   */
+  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String,
+      fClass: Class[F],
+      kClass: Class[K],
+      vClass: Class[V],
+      conf: Configuration): RDD[(K, V)] = {
+    val job = new NewHadoopJob(conf)
+    // Compare qualified paths so that an entry the caller already placed in `conf`
+    // is recognised and the same file is not read twice.
+    val fs = org.apache.hadoop.fs.FileSystem.get(job.getConfiguration)
+    val inputPath = new Path(path).makeQualified(fs)
+    val alreadyListed = NewFileInputFormat.getInputPaths(job).map(_.makeQualified(fs))
+    if (!alreadyListed.contains(inputPath)) {
+      NewFileInputFormat.addInputPath(job, inputPath)
+    }
+    new NewHadoopRDD(this, fClass, kClass, vClass, job.getConfiguration)
+  }
+
   /** Get an RDD for a Hadoop SequenceFile with given key and value types */
   def sequenceFile[K, V](path: String,
       keyClass: Class[K],
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index d21de34e72..b12014e6be 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -115,4 +115,30 @@ class FileSuite extends FunSuite {
     assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
     sc.stop()
   }
+
+  test("write SequenceFile using new Hadoop API") {
+    import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
+    val sc = new SparkContext("local", "test")
+    val tempDir = Files.createTempDir()
+    val outputDir = new File(tempDir, "output").getAbsolutePath
+    val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+    nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]](
+        outputDir)
+    val output = sc.sequenceFile[IntWritable, Text](outputDir)
+    assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+    sc.stop()
+  }
+
+  test("read SequenceFile using new Hadoop API") {
+    import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
+    val sc = new SparkContext("local", "test")
+    val tempDir = Files.createTempDir()
+    val outputDir = new File(tempDir, "output").getAbsolutePath
+    val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+    nums.saveAsSequenceFile(outputDir)
+    val output =
+        sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
+    assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+    sc.stop()
+  }
 }
|