diff options
Diffstat (limited to 'sql/hive')
4 files changed, 24 insertions, 34 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 1af3280e18..1ceacb458a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -83,11 +83,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, bucketId, dataSchema, context) + new OrcOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) } } } @@ -210,8 +210,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) } private[orc] class OrcOutputWriter( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { @@ -226,10 +226,7 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val uniqueWriteJobId = conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val partition = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + val compressionExtension = { val name = conf.get(OrcRelation.ORC_COMPRESSION) OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") @@ -237,12 +234,12 @@ private[orc] class OrcOutputWriter( // It has the `.orc` extension at the end because (de)compression tools // such as gunzip would not be able to decompress this as the compression // is not applied on this whole file but on each "stream" in ORC format. - val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc" + val filename = s"$fileNamePrefix$compressionExtension.orc" new OrcOutputFormat().getRecordWriter( - new Path(path, filename).getFileSystem(conf), + new Path(stagingDir, filename).getFileSystem(conf), conf.asInstanceOf[JobConf], - new Path(path, filename).toString, + new Path(stagingDir, filename).toString, Reporter.NULL ).asInstanceOf[RecordWriter[NullWritable, Writable]] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 997445114b..2eafe18b85 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -54,11 +54,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) } - test("write bucketed data to unsupported data source") { - val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i") - intercept[SparkException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt")) - } - test("write bucketed data using save()") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index 5a8a7f0ab5..d504468402 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -39,11 +39,11 @@ class CommitFailureTestSource extends SimpleTextSource { dataSchema: StructType): OutputWriterFactory = new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) { + new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) { var failed = false TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => failed = true diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 906de6bbcb..9e13b217ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.sql.{sources, Row, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} @@ -51,11 +51,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { SimpleTextRelation.lastHadoopConf = Option(job.getConfiguration) new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) + new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) } } } @@ -120,9 +120,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { +class SimpleTextOutputWriter( + stagingDir: String, fileNamePrefix: String, context: TaskAttemptContext) + extends OutputWriter { private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) + new AppendingTextOutputFormat(new Path(stagingDir), fileNamePrefix).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map { v => @@ -136,19 +138,15 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends } } -class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] { - val numberFormat = NumberFormat.getInstance() +class AppendingTextOutputFormat(stagingDir: Path, fileNamePrefix: String) + extends TextOutputFormat[NullWritable, Text] { + val numberFormat = NumberFormat.getInstance() numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") + new Path(stagingDir, fileNamePrefix) } } |