aboutsummaryrefslogtreecommitdiff
path: root/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala')
-rw-r--r--sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala198
1 files changed, 198 insertions, 0 deletions
diff --git a/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
new file mode 100644
index 0000000000..08d390e887
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred
+
+import java.io.IOException
+import java.text.NumberFormat
+import java.util.Date
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+
+import org.apache.spark.Logging
+import org.apache.spark.SerializableWritable
+
+import org.apache.hadoop.hive.ql.exec.{Utilities, FileSinkOperator}
+import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc
+
+/**
+ * Internal helper class that saves an RDD using a Hive OutputFormat.
+ * It is based on [[SparkHadoopWriter]].
+ */
+protected[apache]
+class SparkHiveHadoopWriter(
+ @transient jobConf: JobConf,
+ fileSinkConf: FileSinkDesc)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
+
+ private val now = new Date()
+ private val conf = new SerializableWritable(jobConf)
+
+ private var jobID = 0
+ private var splitID = 0
+ private var attemptID = 0
+ private var jID: SerializableWritable[JobID] = null
+ private var taID: SerializableWritable[TaskAttemptID] = null
+
+ @transient private var writer: FileSinkOperator.RecordWriter = null
+ @transient private var format: HiveOutputFormat[AnyRef, Writable] = null
+ @transient private var committer: OutputCommitter = null
+ @transient private var jobContext: JobContext = null
+ @transient private var taskContext: TaskAttemptContext = null
+
+ def preSetup() {
+ setIDs(0, 0, 0)
+ setConfParams()
+
+ val jCtxt = getJobContext()
+ getOutputCommitter().setupJob(jCtxt)
+ }
+
+
+ def setup(jobid: Int, splitid: Int, attemptid: Int) {
+ setIDs(jobid, splitid, attemptid)
+ setConfParams()
+ }
+
+ def open() {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+
+ val extension = Utilities.getFileExtension(
+ conf.value,
+ fileSinkConf.getCompressed,
+ getOutputFormat())
+
+ val outputName = "part-" + numfmt.format(splitID) + extension
+ val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName)
+
+ getOutputCommitter().setupTask(getTaskContext())
+ writer = HiveFileFormatUtils.getHiveRecordWriter(
+ conf.value,
+ fileSinkConf.getTableInfo,
+ conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+ fileSinkConf,
+ path,
+ null)
+ }
+
+ def write(value: Writable) {
+ if (writer != null) {
+ writer.write(value)
+ } else {
+ throw new IOException("Writer is null, open() has not been called")
+ }
+ }
+
+ def close() {
+ // Seems the boolean value passed into close does not matter.
+ writer.close(false)
+ }
+
+ def commit() {
+ val taCtxt = getTaskContext()
+ val cmtr = getOutputCommitter()
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ try {
+ cmtr.commitTask(taCtxt)
+ logInfo (taID + ": Committed")
+ } catch {
+ case e: IOException => {
+ logError("Error committing the output of task: " + taID.value, e)
+ cmtr.abortTask(taCtxt)
+ throw e
+ }
+ }
+ } else {
+ logWarning ("No need to commit output of task: " + taID.value)
+ }
+ }
+
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
+ // ********* Private Functions *********
+
+ private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = {
+ if (format == null) {
+ format = conf.value.getOutputFormat()
+ .asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
+ }
+ format
+ }
+
+ private def getOutputCommitter(): OutputCommitter = {
+ if (committer == null) {
+ committer = conf.value.getOutputCommitter
+ }
+ committer
+ }
+
+ private def getJobContext(): JobContext = {
+ if (jobContext == null) {
+ jobContext = newJobContext(conf.value, jID.value)
+ }
+ jobContext
+ }
+
+ private def getTaskContext(): TaskAttemptContext = {
+ if (taskContext == null) {
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
+ }
+ taskContext
+ }
+
+ private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
+ jobID = jobid
+ splitID = splitid
+ attemptID = attemptid
+
+ jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ taID = new SerializableWritable[TaskAttemptID](
+ new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
+ }
+
+ private def setConfParams() {
+ conf.value.set("mapred.job.id", jID.value.toString)
+ conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
+ conf.value.set("mapred.task.id", taID.value.toString)
+ conf.value.setBoolean("mapred.task.is.map", true)
+ conf.value.setInt("mapred.task.partition", splitID)
+ }
+}
+
+object SparkHiveHadoopWriter {
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (outputPath == null || fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs)
+ }
+}