author     Charles Reiss <charles@eecs.berkeley.edu>  2011-12-01 14:01:28 -0800
committer  Charles Reiss <charles@eecs.berkeley.edu>  2011-12-01 14:01:28 -0800
commit     02d43e6986edd4cc656e97da69a33777c05ba1af (patch)
tree       d0944c36b2fa19d8b5758a944d25d9cdf27b5fbd /core
parent     07532021fee9e2d27ee954b21c30830687478d8b (diff)
Add new Hadoop API writing support.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala  55
-rw-r--r--  core/src/test/scala/spark/FileSuite.scala          13
2 files changed, 68 insertions, 0 deletions
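
For context, a minimal usage sketch of the new type-parameterized entry point this patch adds. It is not part of the commit; it assumes an existing SparkContext named sc, a destination path that does not already exist, and the implicit conversion to PairRDDFunctions from SparkContext._:

// Hypothetical usage sketch, not part of this diff.
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
import spark.SparkContext._   // implicit conversion RDD[(K, V)] => PairRDDFunctions

val pairs = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
// The output format is supplied as a type parameter; the key and value classes
// are derived from the RDD's element types via their ClassManifests.
pairs.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]]("/tmp/pairs-out")
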
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 6a743bf096..074d5abb38 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -7,11 +7,13 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.HashSet
import java.util.Random
import java.util.Date
+import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
@@ -25,6 +27,13 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
+import org.apache.hadoop.mapreduce.TaskAttemptID
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
import SparkContext._
/**
@@ -239,6 +248,52 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) ex
def saveAsHadoopFile [F <: OutputFormat[K, V]] (path: String) (implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
+
+ def saveAsNewAPIHadoopFile [F <: NewOutputFormat[K, V]] (path: String) (implicit fm: ClassManifest[F]) {
+ saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ }
+
+ def saveAsNewAPIHadoopFile(path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
+ val job = new NewAPIHadoopJob
+ job.setOutputKeyClass(keyClass)
+ job.setOutputValueClass(valueClass)
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
+ NewFileOutputFormat.setOutputPath(job, new Path(path))
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val jobtrackerID = formatter.format(new Date())
+ val stageId = self.id
+ def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = {
+ /* "reduce task" <split #> <attempt # = spark task #> */
+ val attemptId = new TaskAttemptID(jobtrackerID,
+ stageId, false, context.splitId, context.attemptId)
+ val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId)
+ val format = outputFormatClass.newInstance
+ val committer = format.getOutputCommitter(hadoopContext)
+ committer.setupTask(hadoopContext)
+ val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ while (iter.hasNext) {
+ val (k, v) = iter.next
+ writer.write(k, v)
+ }
+ writer.close(hadoopContext)
+ committer.commitTask(hadoopContext)
+ return 1
+ }
+ val jobFormat = outputFormatClass.newInstance
+ /* apparently we need a TaskAttemptID to construct an OutputCommitter;
+ * however we're only going to use this local OutputCommitter for
+ * setupJob/commitJob, so we just use a dummy "map" task.
+ */
+ val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+ val jobTaskContext = new TaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+ jobCommitter.setupJob(jobTaskContext)
+ val count = self.context.runJob(self, writeShard _).sum
+ jobCommitter.cleanupJob(jobTaskContext)
+ }
def saveAsHadoopFile(path: String,
keyClass: Class[_],
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index d21de34e72..bb2d0c658b 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -115,4 +115,17 @@ class FileSuite extends FunSuite {
assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
sc.stop()
}
+
+ test("write SequenceFile using new Hadoop API") {
+ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
+ val sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+ nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]](
+ outputDir)
+ val output = sc.sequenceFile[IntWritable, Text](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ sc.stop()
+ }
}
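
The same write can also be expressed through the explicit four-argument overload added above, passing the key, value, and output-format classes directly instead of relying on ClassManifests. This is only a sketch under the same assumptions (an existing SparkContext sc, an illustrative output path):

// Hypothetical usage sketch, not part of this diff.
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
import spark.SparkContext._

val pairs = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
pairs.saveAsNewAPIHadoopFile(
  "/tmp/pairs-out",                                       // output directory (must not already exist)
  classOf[IntWritable],                                    // key class
  classOf[Text],                                           // value class
  classOf[SequenceFileOutputFormat[IntWritable, Text]])    // new-API OutputFormat to instantiate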