-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala                                    | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala                         | 77
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala                          | 49
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala                                 | 43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala   |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala   | 49
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala  | 94
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala         |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala        |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala                    | 61
10 files changed, 268 insertions(+), 152 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index ae4320d458..e3d81a6be5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -131,9 +131,9 @@ class HadoopRDD[K, V](
minPartitions)
}
- protected val jobConfCacheKey = "rdd_%d_job_conf".format(id)
+ protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id)
- protected val inputFormatCacheKey = "rdd_%d_input_format".format(id)
+ protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id)
// used to build JobTracker ID
private val createTime = new Date()
@@ -210,22 +210,24 @@ class HadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopPartition]
+ private val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- val jobConf = getJobConf()
+ private val jobConf = getJobConf()
- val inputMetrics = context.taskMetrics().inputMetrics
- val existingBytesRead = inputMetrics.bytesRead
+ private val inputMetrics = context.taskMetrics().inputMetrics
+ private val existingBytesRead = inputMetrics.bytesRead
- // Sets the thread local variable for the file's name
+ // Sets InputFileBlockHolder with the file block's information
split.inputSplit.value match {
- case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString)
- case _ => InputFileNameHolder.unsetInputFileName()
+ case fs: FileSplit =>
+ InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
+ case _ =>
+ InputFileBlockHolder.unset()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
+ private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
case _: FileSplit | _: CombineFileSplit =>
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
@@ -235,14 +237,14 @@ class HadoopRDD[K, V](
// If we do a coalesce, however, we are likely to compute multiple partitions in the same
// task and in the same thread, in which case we need to avoid override values written by
// previous partitions (SPARK-13071).
- def updateBytesRead(): Unit = {
+ private def updateBytesRead(): Unit = {
getBytesReadCallback.foreach { getBytesRead =>
inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}
- var reader: RecordReader[K, V] = null
- val inputFormat = getInputFormat(jobConf)
+ private var reader: RecordReader[K, V] = null
+ private val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(
new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime),
context.stageId, theSplit.index, context.attemptNumber, jobConf)
@@ -250,8 +252,8 @@ class HadoopRDD[K, V](
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener{ context => closeIfNeeded() }
- val key: K = reader.createKey()
- val value: V = reader.createValue()
+ private val key: K = reader.createKey()
+ private val value: V = reader.createValue()
override def getNext(): (K, V) = {
try {
@@ -270,7 +272,7 @@ class HadoopRDD[K, V](
override def close() {
if (reader != null) {
- InputFileNameHolder.unsetInputFileName()
+ InputFileBlockHolder.unset()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
new file mode 100644
index 0000000000..9ba476d2ba
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This holds the file path and block information (start offset and length) of the file block
+ * a Spark task is reading. Used by HadoopRDD, FileScanRDD, NewHadoopRDD and Spark SQL.
+ */
+private[spark] object InputFileBlockHolder {
+ /**
+ * A wrapper around some input file information.
+ *
+ * @param filePath path of the file read, or empty string if not available.
+ * @param startOffset starting offset, in bytes, or -1 if not available.
+ * @param length size of the block, in bytes, or -1 if not available.
+ */
+ private class FileBlock(val filePath: UTF8String, val startOffset: Long, val length: Long) {
+ def this() {
+ this(UTF8String.fromString(""), -1, -1)
+ }
+ }
+
+ /**
+ * The thread-local variable holding the file block currently being read. This is used by
+ * the input file name/block functions in Spark SQL.
+ */
+ private[this] val inputBlock: ThreadLocal[FileBlock] = new ThreadLocal[FileBlock] {
+ override protected def initialValue(): FileBlock = new FileBlock
+ }
+
+ /**
+ * Returns the path of the file being read, or empty string if it is unknown.
+ */
+ def getInputFilePath: UTF8String = inputBlock.get().filePath
+
+ /**
+ * Returns the starting offset of the block currently being read, or -1 if it is unknown.
+ */
+ def getStartOffset: Long = inputBlock.get().startOffset
+
+ /**
+ * Returns the length of the block being read, or -1 if it is unknown.
+ */
+ def getLength: Long = inputBlock.get().length
+
+ /**
+ * Sets the thread-local input block.
+ */
+ def set(filePath: String, startOffset: Long, length: Long): Unit = {
+ require(filePath != null, "filePath cannot be null")
+ require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative")
+ require(length >= 0, s"length ($length) cannot be negative")
+ inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
+ }
+
+ /**
+ * Clears the input file block, resetting it to its default value.
+ */
+ def unset(): Unit = inputBlock.remove()
+}
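
For context, a minimal usage sketch of the new holder follows (not part of this patch). The package name, object name, file path and offsets are illustrative assumptions; the sketch has to live under org.apache.spark because InputFileBlockHolder is private[spark].

// Hypothetical sketch of the holder's per-split lifecycle: defaults -> set -> read -> unset.
package org.apache.spark.example

import org.apache.spark.rdd.InputFileBlockHolder

object InputFileBlockHolderSketch {
  def main(args: Array[String]): Unit = {
    // Before any split has been registered, the holder reports its defaults.
    assert(InputFileBlockHolder.getInputFilePath.toString == "")
    assert(InputFileBlockHolder.getStartOffset == -1L && InputFileBlockHolder.getLength == -1L)

    // A reader publishes the current block before emitting records for a split...
    InputFileBlockHolder.set("/tmp/part-00000.parquet", 0L, 4096L)
    assert(InputFileBlockHolder.getInputFilePath.toString == "/tmp/part-00000.parquet")
    assert(InputFileBlockHolder.getStartOffset == 0L)
    assert(InputFileBlockHolder.getLength == 4096L)

    // ...and clears it once the split is exhausted, restoring the defaults.
    InputFileBlockHolder.unset()
    assert(InputFileBlockHolder.getStartOffset == -1L)
  }
}
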
diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
deleted file mode 100644
index 960c91a154..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * This holds file names of the current Spark task. This is used in HadoopRDD,
- * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL.
- *
- * The returned value should never be null but empty string if it is unknown.
- */
-private[spark] object InputFileNameHolder {
- /**
- * The thread variable for the name of the current file being read. This is used by
- * the InputFileName function in Spark SQL.
- */
- private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
- override protected def initialValue(): UTF8String = UTF8String.fromString("")
- }
-
- /**
- * Returns the holding file name or empty string if it is unknown.
- */
- def getInputFileName(): UTF8String = inputFileName.get()
-
- private[spark] def setInputFileName(file: String) = {
- require(file != null, "The input file name cannot be null")
- inputFileName.set(UTF8String.fromString(file))
- }
-
- private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
-
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index c783e13752..e90e84c459 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -132,54 +132,57 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopPartition]
+ private val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = getConf
+ private val conf = getConf
- val inputMetrics = context.taskMetrics().inputMetrics
- val existingBytesRead = inputMetrics.bytesRead
+ private val inputMetrics = context.taskMetrics().inputMetrics
+ private val existingBytesRead = inputMetrics.bytesRead
- // Sets the thread local variable for the file's name
+ // Sets InputFileBlockHolder with the file block's information
split.serializableHadoopSplit.value match {
- case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString)
- case _ => InputFileNameHolder.unsetInputFileName()
+ case fs: FileSplit =>
+ InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
+ case _ =>
+ InputFileBlockHolder.unset()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
- case _: FileSplit | _: CombineFileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
- case _ => None
- }
+ private val getBytesReadCallback: Option[() => Long] =
+ split.serializableHadoopSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
// For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
// If we do a coalesce, however, we are likely to compute multiple partitions in the same
// task and in the same thread, in which case we need to avoid override values written by
// previous partitions (SPARK-13071).
- def updateBytesRead(): Unit = {
+ private def updateBytesRead(): Unit = {
getBytesReadCallback.foreach { getBytesRead =>
inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}
- val format = inputFormatClass.newInstance
+ private val format = inputFormatClass.newInstance
format match {
case configurable: Configurable =>
configurable.setConf(conf)
case _ =>
}
- val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0)
- val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0)
+ private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
private var reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
- var havePair = false
- var finished = false
- var recordsSinceMetricsUpdate = 0
+ private var havePair = false
+ private var finished = false
+ private var recordsSinceMetricsUpdate = 0
override def hasNext: Boolean = {
if (!finished && !havePair) {
@@ -215,7 +218,7 @@ class NewHadoopRDD[K, V](
private def close() {
if (reader != null) {
- InputFileNameHolder.unsetInputFileName()
+ InputFileBlockHolder.unset()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e41f1cad93..5d065d736e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -371,6 +371,8 @@ object FunctionRegistry {
expression[Sha2]("sha2"),
expression[SparkPartitionID]("spark_partition_id"),
expression[InputFileName]("input_file_name"),
+ expression[InputFileBlockStart]("input_file_block_start"),
+ expression[InputFileBlockLength]("input_file_block_length"),
expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
expression[CurrentDatabase]("current_database"),
expression[CallMethodViaReflection]("reflect"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
deleted file mode 100644
index d412336699..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.rdd.InputFileNameHolder
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{DataType, StringType}
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * Expression that returns the name of the current file being read.
- */
-@ExpressionDescription(
- usage = "_FUNC_() - Returns the name of the current file being read if available.")
-case class InputFileName() extends LeafExpression with Nondeterministic {
-
- override def nullable: Boolean = false
-
- override def dataType: DataType = StringType
-
- override def prettyName: String = "input_file_name"
-
- override protected def initializeInternal(partitionIndex: Int): Unit = {}
-
- override protected def evalInternal(input: InternalRow): UTF8String = {
- InputFileNameHolder.getInputFileName()
- }
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false")
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
new file mode 100644
index 0000000000..7a8edabed1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.rdd.InputFileBlockHolder
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.{DataType, LongType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.")
+case class InputFileName() extends LeafExpression with Nondeterministic {
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = StringType
+
+ override def prettyName: String = "input_file_name"
+
+ override protected def initializeInternal(partitionIndex: Int): Unit = {}
+
+ override protected def evalInternal(input: InternalRow): UTF8String = {
+ InputFileBlockHolder.getInputFilePath
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
+ ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
+ s"$className.getInputFilePath();", isNull = "false")
+ }
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.")
+case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = LongType
+
+ override def prettyName: String = "input_file_block_start"
+
+ override protected def initializeInternal(partitionIndex: Int): Unit = {}
+
+ override protected def evalInternal(input: InternalRow): Long = {
+ InputFileBlockHolder.getStartOffset
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
+ ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
+ s"$className.getStartOffset();", isNull = "false")
+ }
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.")
+case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = LongType
+
+ override def prettyName: String = "input_file_block_length"
+
+ override protected def initializeInternal(partitionIndex: Int): Unit = {}
+
+ override protected def evalInternal(input: InternalRow): Long = {
+ InputFileBlockHolder.getLength
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
+ ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
+ s"$className.getLength();", isNull = "false")
+ }
+}
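
For context, a minimal sketch (not part of this patch) of querying the three expressions end to end through the DataFrame API; the local master, temp directory and sample data are assumptions for illustration. FileScanRDD (changed below) populates the thread-local holder per file, which is what these expressions read.

import java.nio.file.Files

import org.apache.spark.sql.SparkSession

object InputFileBlockQuerySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("input-file-block").getOrCreate()
    import spark.implicits._

    // Write a small Parquet file so there is a real file block for the expressions to report.
    val path = Files.createTempDirectory("input-file-block").resolve("data").toString
    Seq(1, 2, 3).toDF("id").write.parquet(path)

    // Each expression reads the thread-local InputFileBlockHolder set per file during the scan.
    spark.read.parquet(path)
      .selectExpr("input_file_name()", "input_file_block_start()", "input_file_block_length()")
      .show(truncate = false)

    spark.stop()
  }
}
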
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index f47eb84df0..eaccdf27da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -84,7 +84,7 @@ case class DataSource(
case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])
lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
- lazy val sourceInfo = sourceSchema()
+ lazy val sourceInfo: SourceInfo = sourceSchema()
private val caseInsensitiveOptions = new CaseInsensitiveMap(options)
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 89944570df..306dc6527e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.{Partition => RDDPartition, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.rdd.{InputFileNameHolder, RDD}
+import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
@@ -121,7 +121,8 @@ class FileScanRDD(
if (files.hasNext) {
currentFile = files.next()
logInfo(s"Reading File $currentFile")
- InputFileNameHolder.setInputFileName(currentFile.filePath)
+ // Sets InputFileBlockHolder with the file block's information
+ InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length)
try {
if (ignoreCorruptFiles) {
@@ -162,7 +163,7 @@ class FileScanRDD(
hasNext
} else {
currentFile = null
- InputFileNameHolder.unsetInputFileName()
+ InputFileBlockHolder.unset()
false
}
}
@@ -170,7 +171,7 @@ class FileScanRDD(
override def close(): Unit = {
updateBytesRead()
updateBytesReadWithFileSize()
- InputFileNameHolder.unsetInputFileName()
+ InputFileBlockHolder.unset()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 26e1a9f75d..b0339a88fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -533,31 +533,54 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
)
}
- test("input_file_name - FileScanRDD") {
+ test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") {
withTempPath { dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)
- val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name())
- .head.getString(0)
- assert(answer.contains(dir.getCanonicalPath))
- checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+ // Test the 3 expressions when reading from files
+ val q = spark.read.parquet(dir.getCanonicalPath).select(
+ input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()"))
+ val firstRow = q.head()
+ assert(firstRow.getString(0).contains(dir.getCanonicalPath))
+ assert(firstRow.getLong(1) == 0)
+ assert(firstRow.getLong(2) > 0)
+
+ // Now read directly from the original RDD without going through any files to make sure
+ // we are returning empty string, -1, and -1.
+ checkAnswer(
+ data.select(
+ input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")
+ ).limit(1),
+ Row("", -1L, -1L))
}
}
- test("input_file_name - HadoopRDD") {
+ test("input_file_name, input_file_block_start, input_file_block_length - HadoopRDD") {
withTempPath { dir =>
val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF()
data.write.text(dir.getCanonicalPath)
val df = spark.sparkContext.textFile(dir.getCanonicalPath).toDF()
- val answer = df.select(input_file_name()).head.getString(0)
- assert(answer.contains(dir.getCanonicalPath))
- checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+ // Test the 3 expressions when reading from files
+ val q = df.select(
+ input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()"))
+ val firstRow = q.head()
+ assert(firstRow.getString(0).contains(dir.getCanonicalPath))
+ assert(firstRow.getLong(1) == 0)
+ assert(firstRow.getLong(2) > 0)
+
+ // Now read directly from the original RDD without going through any files to make sure
+ // we are returning empty string, -1, and -1.
+ checkAnswer(
+ data.select(
+ input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")
+ ).limit(1),
+ Row("", -1L, -1L))
}
}
- test("input_file_name - NewHadoopRDD") {
+ test("input_file_name, input_file_block_start, input_file_block_length - NewHadoopRDD") {
withTempPath { dir =>
val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF()
data.write.text(dir.getCanonicalPath)
@@ -567,10 +590,22 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
classOf[LongWritable],
classOf[Text])
val df = rdd.map(pair => pair._2.toString).toDF()
- val answer = df.select(input_file_name()).head.getString(0)
- assert(answer.contains(dir.getCanonicalPath))
- checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+ // Test the 3 expressions when reading from files
+ val q = df.select(
+ input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()"))
+ val firstRow = q.head()
+ assert(firstRow.getString(0).contains(dir.getCanonicalPath))
+ assert(firstRow.getLong(1) == 0)
+ assert(firstRow.getLong(2) > 0)
+
+ // Now read directly from the original RDD without going through any files to make sure
+ // we are returning empty string, -1, and -1.
+ checkAnswer(
+ data.select(
+ input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")
+ ).limit(1),
+ Row("", -1L, -1L))
}
}