 core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala      |  6 +++++-
 core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala | 75 +++++++++++++++++++
 2 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 39c3a4996c..d29a1a9881 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -29,7 +29,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
-import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -618,6 +618,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
         attemptNumber)
       val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
       val format = outputFormatClass.newInstance
+      format match {
+        case c: Configurable => c.setConf(wrappedConf.value)
+        case _ => ()
+      }
       val committer = format.getOutputCommitter(hadoopContext)
       committer.setupTask(hadoopContext)
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
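
The added match exists because of how the output format is created: Class.newInstance calls the no-arg constructor and nothing else, so a format that implements Hadoop's Configurable interface never receives the job Configuration. Hadoop's own ReflectionUtils.newInstance performs that injection automatically, but it is not used here, so the patch replicates it by hand. A minimal sketch of the pattern in isolation, with a hypothetical helper name:

    import org.apache.hadoop.conf.{Configurable, Configuration}

    // Sketch (assumes nothing beyond hadoop-common): create an instance
    // reflectively and, if it is Configurable, hand it the Configuration
    // before any of its other methods run; this is the same injection
    // that ReflectionUtils.newInstance would perform.
    def newConfiguredInstance[T](clazz: Class[T], conf: Configuration): T = {
      val instance = clazz.newInstance()
      instance match {
        case c: Configurable => c.setConf(conf)
        case _ => ()
      }
      instance
    }

Formats that depend on their Configuration (for codecs, schemas, and the like) typically fail inside getRecordWriter when setConf is skipped, which is exactly what the new test below asserts.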
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index fa5c9b10fe..e3e23775f0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -23,6 +23,8 @@ import scala.util.Random
 import org.scalatest.FunSuite
 
 import com.google.common.io.Files
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.conf.{Configuration, Configurable}
 import org.apache.spark.SparkContext._
 import org.apache.spark.{Partitioner, SharedSparkContext}
@@ -330,4 +332,77 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
       (1, ArrayBuffer(1)),
       (2, ArrayBuffer(1))))
   }
+
+  test("saveNewAPIHadoopFile should call setConf if format is configurable") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
+
+    // Non-configurable formats should still work without error.
+    pairs.saveAsNewAPIHadoopFile[FakeFormat]("ignored")
+
+    /*
+     * Check that configurable formats get configured:
+     * ConfigTestFormat throws an exception if we try to write to it
+     * before setConf has been called.
+     * The assertion is in ConfigTestFormat.getRecordWriter.
+     */
+    pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
+  }
 }
+
+/*
+ * These classes are fakes for testing
+ * "saveNewAPIHadoopFile should call setConf if format is configurable".
+ * They must be top-level classes rather than being defined inside the test
+ * method: otherwise Scala won't generate no-arg constructors, and the test
+ * would throw InstantiationException when saveAsNewAPIHadoopFile tries to
+ * instantiate them with Class.newInstance.
+ */
+class FakeWriter extends RecordWriter[Integer, Integer] {
+
+  def close(p1: TaskAttemptContext) = ()
+
+  def write(p1: Integer, p2: Integer) = ()
+
+}
+
+class FakeCommitter extends OutputCommitter {
+  def setupJob(p1: JobContext) = ()
+
+  def needsTaskCommit(p1: TaskAttemptContext): Boolean = false
+
+  def setupTask(p1: TaskAttemptContext) = ()
+
+  def commitTask(p1: TaskAttemptContext) = ()
+
+  def abortTask(p1: TaskAttemptContext) = ()
+}
+
+class FakeFormat() extends OutputFormat[Integer, Integer]() {
+
+  def checkOutputSpecs(p1: JobContext) = ()
+
+  def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+    new FakeWriter()
+  }
+
+  def getOutputCommitter(p1: TaskAttemptContext): OutputCommitter = {
+    new FakeCommitter()
+  }
+}
+
+class ConfigTestFormat() extends FakeFormat() with Configurable {
+
+  var setConfCalled = false
+  def setConf(p1: Configuration) = {
+    setConfCalled = true
+    ()
+  }
+
+  def getConf: Configuration = null
+
+  override def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+    assert(setConfCalled, "setConf was never called")
+    super.getRecordWriter(p1)
+  }
+}
+
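
As a usage note, the same contract matters outside the test suite for any output format whose writer depends on job configuration. A hedged sketch of such a format (the class and its failure behavior below are hypothetical, modeled on the ConfigTestFormat fake above):

    import org.apache.hadoop.conf.{Configurable, Configuration}
    import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
    import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

    // Hypothetical user-defined format: it refuses to hand out a writer
    // until setConf has supplied the Configuration it needs.
    class ConfDependentTextFormat extends TextOutputFormat[Integer, Integer] with Configurable {
      private var conf: Configuration = _

      def setConf(c: Configuration): Unit = { conf = c }
      def getConf: Configuration = conf

      override def getRecordWriter(ctx: TaskAttemptContext): RecordWriter[Integer, Integer] = {
        require(conf != null, "setConf must be called before writing records")
        super.getRecordWriter(ctx)
      }
    }

Without the fix in PairRDDFunctions, pairs.saveAsNewAPIHadoopFile[ConfDependentTextFormat]("out") would trip the require above on every task; with it, the format is configured before getRecordWriter is called.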