Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/SparkContext.scala  | 14
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala  | 78
2 files changed, 71 insertions, 21 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 72540c712a..894cc67acf 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -332,7 +332,7 @@ class SparkContext(
       valueClass: Class[V],
       minSplits: Int = defaultMinSplits
       ): RDD[(K, V)] = {
-    new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+    new HadoopDatasetRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
   }
 
   /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
@@ -343,9 +343,15 @@ class SparkContext(
       valueClass: Class[V],
       minSplits: Int = defaultMinSplits
       ) : RDD[(K, V)] = {
-    val conf = new JobConf(hadoopConfiguration)
-    FileInputFormat.setInputPaths(conf, path)
-    new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+    val broadcastHadoopConfiguration = broadcast(new SerializableWritable(hadoopConfiguration))
+    new HadoopFileRDD(
+      this,
+      path,
+      broadcastHadoopConfiguration,
+      inputFormatClass,
+      keyClass,
+      valueClass,
+      minSplits)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 2cb6734e41..e259ef52a9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.EOFException
 
+import org.apache.hadoop.mapred.FileInputFormat
 import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.InputSplit
 import org.apache.hadoop.mapred.JobConf
@@ -26,10 +27,55 @@ import org.apache.hadoop.mapred.RecordReader
 import org.apache.hadoop.mapred.Reporter
 import org.apache.hadoop.util.ReflectionUtils
 
-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
+  TaskContext}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.util.NextIterator
 import org.apache.hadoop.conf.{Configuration, Configurable}
 
+/**
+ * An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file
+ * system, or S3).
+ * This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same
+ * across multiple reads; the 'path' is the only variable that is different across new JobConfs
+ * created from the Configuration.
+ */
+class HadoopFileRDD[K, V](
+    sc: SparkContext,
+    path: String,
+    hadoopConfBroadcast: Broadcast[SerializableWritable[Configuration]],
+    inputFormatClass: Class[_ <: InputFormat[K, V]],
+    keyClass: Class[K],
+    valueClass: Class[V],
+    minSplits: Int)
+  extends HadoopRDD[K, V](sc, inputFormatClass, keyClass, valueClass, minSplits) {
+
+  private val localJobConf: JobConf = {
+    val jobConf = new JobConf(hadoopConfBroadcast.value.value)
+    FileInputFormat.setInputPaths(jobConf, path)
+    jobConf
+  }
+
+  override def getJobConf: JobConf = localJobConf
+}
+
+/**
+ * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. tables in HBase).
+ */
+class HadoopDatasetRDD[K, V](
+    sc: SparkContext,
+    @transient conf: JobConf,
+    inputFormatClass: Class[_ <: InputFormat[K, V]],
+    keyClass: Class[K],
+    valueClass: Class[V],
+    minSplits: Int)
+  extends HadoopRDD[K, V](sc, inputFormatClass, keyClass, valueClass, minSplits) {
+
+  // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it.
+  private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+
+  override def getJobConf: JobConf = confBroadcast.value.value
+}
+
 /**
  * A Spark split class that wraps around a Hadoop InputSplit.
@@ -45,29 +91,30 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
 }
 
 /**
- * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. files in HDFS, the local file
- * system, or S3, tables in HBase, etc).
+ * A base class that provides core functionality for reading data partitions stored in Hadoop.
  */
-class HadoopRDD[K, V](
+abstract class HadoopRDD[K, V](
     sc: SparkContext,
-    @transient conf: JobConf,
     inputFormatClass: Class[_ <: InputFormat[K, V]],
     keyClass: Class[K],
     valueClass: Class[V],
     minSplits: Int)
   extends RDD[(K, V)](sc, Nil) with Logging {
 
-  // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
-  private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+  // The JobConf used to obtain input splits for Hadoop reads. The subclass is responsible for
+  // determining how the JobConf is initialized.
+  protected def getJobConf: JobConf
+
+  def getConf: Configuration = getJobConf
 
   override def getPartitions: Array[Partition] = {
     val env = SparkEnv.get
-    env.hadoop.addCredentials(conf)
-    val inputFormat = createInputFormat(conf)
+    env.hadoop.addCredentials(getJobConf)
+    val inputFormat = createInputFormat(getJobConf)
     if (inputFormat.isInstanceOf[Configurable]) {
-      inputFormat.asInstanceOf[Configurable].setConf(conf)
+      inputFormat.asInstanceOf[Configurable].setConf(getJobConf)
     }
-    val inputSplits = inputFormat.getSplits(conf, minSplits)
+    val inputSplits = inputFormat.getSplits(getJobConf, minSplits)
     val array = new Array[Partition](inputSplits.size)
     for (i <- 0 until inputSplits.size) {
       array(i) = new HadoopPartition(id, i, inputSplits(i))
@@ -85,12 +132,11 @@ class HadoopRDD[K, V](
     logInfo("Input split: " + split.inputSplit)
     var reader: RecordReader[K, V] = null
 
-    val conf = confBroadcast.value.value
-    val fmt = createInputFormat(conf)
+    val fmt = createInputFormat(getJobConf)
     if (fmt.isInstanceOf[Configurable]) {
-      fmt.asInstanceOf[Configurable].setConf(conf)
+      fmt.asInstanceOf[Configurable].setConf(getJobConf)
     }
-    reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+    reader = fmt.getRecordReader(split.inputSplit.value, getJobConf, Reporter.NULL)
 
     // Register an on-task-completion callback to close the input stream.
     context.addOnCompleteCallback{ () => closeIfNeeded() }
@@ -126,6 +172,4 @@ class HadoopRDD[K, V](
   override def checkpoint() {
     // Do nothing. Hadoop RDD should not be checkpointed.
   }
-
-  def getConf: Configuration = confBroadcast.value.value
 }
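
For context, the following is a minimal, hypothetical driver-side sketch of what this split means for callers. The master string, app name, and input paths are placeholders, and it assumes the old-style org.apache.hadoop.mapred API used in this patch. After this change, hadoopFile constructs a HadoopFileRDD backed by one broadcast Hadoop Configuration that can be shared across file reads, while hadoopRDD constructs a HadoopDatasetRDD that broadcasts the caller-supplied JobConf.

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileInputFormat, JobConf, TextInputFormat}
import org.apache.spark.SparkContext

object HadoopRDDExample {
  def main(args: Array[String]) {
    // Placeholder master and app name.
    val sc = new SparkContext("local", "HadoopRDDExample")

    // hadoopFile now returns a HadoopFileRDD: the Hadoop Configuration is broadcast once,
    // and each RDD only sets its own (hypothetical) input path on a locally built JobConf.
    val logsA = sc.hadoopFile("hdfs:///logs/a", classOf[TextInputFormat],
      classOf[LongWritable], classOf[Text])
    val logsB = sc.hadoopFile("hdfs:///logs/b", classOf[TextInputFormat],
      classOf[LongWritable], classOf[Text])

    // hadoopRDD still takes a fully prepared JobConf (e.g. one configured for an HBase table)
    // and now returns a HadoopDatasetRDD, which broadcasts that JobConf itself.
    val jobConf = new JobConf()
    FileInputFormat.setInputPaths(jobConf, "hdfs:///logs/c")  // placeholder path
    val logsC = sc.hadoopRDD(jobConf, classOf[TextInputFormat],
      classOf[LongWritable], classOf[Text])

    println(logsA.count() + logsB.count() + logsC.count())
    sc.stop()
  }
}

The point of the refactoring is visible in the first two calls: instead of serializing a separate, roughly 10 KB JobConf with every file-based RDD, the driver broadcasts the shared Configuration once and each HadoopFileRDD derives a per-path JobConf from it via its getJobConf hook.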