-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala                        |  7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala  | 93
2 files changed, 98 insertions(+), 2 deletions(-)
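
Summary: input_file_name() previously only returned a real path for reads that went through SqlNewHadoopRDD (e.g. Parquet); for tables read through HadoopRDD (Hive SerDe and delimited-text tables) it came back empty. This patch sets the per-thread input file name from the split when a HadoopRDD partition starts being read, and clears it again in close(), so input_file_name() also works for non-Parquet tables (SPARK-11522).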
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 0453614f6a..7db5834687 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -213,6 +213,12 @@ class HadoopRDD[K, V](
val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+ // Sets the thread local variable for the file's name
+ split.inputSplit.value match {
+ case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDD.unsetInputFileName()
+ }
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -250,6 +256,7 @@ class HadoopRDD[K, V](
override def close() {
if (reader != null) {
+ SqlNewHadoopRDD.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
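
For context, the SqlNewHadoopRDD.setInputFileName / unsetInputFileName calls above are assumed to maintain a thread-local holder that the input_file_name() expression reads at evaluation time. A minimal sketch of that pattern (simplified names and a plain String instead of Spark's internal types; illustrative, not the exact Spark source):

    // Sketch: per-thread holder for the file backing the split currently being read.
    object InputFileNameHolder {
      private val inputFileName = new ThreadLocal[String] {
        override protected def initialValue(): String = ""  // "" when no file is being read
      }

      def getInputFileName(): String = inputFileName.get()               // read by input_file_name()
      def setInputFileName(file: String): Unit = inputFileName.set(file) // set once per split
      def unsetInputFileName(): Unit = inputFileName.remove()            // cleared in close()
    }

Because the holder is thread-local, it must be set when the task begins reading a split and cleared in close(); a pooled executor thread could otherwise report a stale file name for the next task it runs.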
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 5ab477efc4..9deb1a6db1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive.execution
-import java.io.{DataInput, DataOutput}
+import java.io.{DataInput, DataOutput, File, PrintWriter}
import java.util.{ArrayList, Arrays, Properties}
import org.apache.hadoop.conf.Configuration
@@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.io.Writable
+import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.util.Utils
@@ -44,7 +45,7 @@ case class ListStringCaseClass(l: Seq[String])
/**
* A test suite for Hive custom UDFs.
*/
-class HiveUDFSuite extends QueryTest with TestHiveSingleton {
+class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
import hiveContext.{udf, sql}
import hiveContext.implicits._
@@ -348,6 +349,94 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {
sqlContext.dropTempTable("testUDF")
}
+
+ test("SPARK-11522 select input_file_name from non-parquet table"){
+
+ withTempDir { tempDir =>
+
+ // EXTERNAL OpenCSVSerde table pointing to LOCATION
+
+ val file1 = new File(tempDir + "/data1")
+ val writer1 = new PrintWriter(file1)
+ writer1.write("1,2")
+ writer1.close()
+
+ val file2 = new File(tempDir + "/data2")
+ val writer2 = new PrintWriter(file2)
+ writer2.write("1,2")
+ writer2.close()
+
+ sql(
+ s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT)
+ ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'
+ WITH SERDEPROPERTIES (
+ \"separatorChar\" = \",\",
+ \"quoteChar\" = \"\\\"\",
+ \"escapeChar\" = \"\\\\\")
+ LOCATION '$tempDir'
+ """)
+
+ val answer1 =
+ sql("SELECT input_file_name() FROM csv_table").head().getString(0)
+ assert(answer1.contains("data1") || answer1.contains("data2"))
+
+ val count1 = sql("SELECT input_file_name() FROM csv_table").distinct().count()
+ assert(count1 == 2)
+ sql("DROP TABLE csv_table")
+
+ // EXTERNAL delimited-text table pointing to LOCATION
+
+ sql(
+ s"""CREATE EXTERNAL TABLE external_t5 (c1 int, c2 int)
+ ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ LOCATION '$tempDir'
+ """)
+
+ val answer2 =
+ sql("SELECT input_file_name() as file FROM external_t5").head().getString(0)
+ assert(answer2.contains("data1") || answer2.contains("data2"))
+
+ val count2 = sql("SELECT input_file_name() as file FROM external_t5").distinct().count()
+ assert(count2 == 2)
+ sql("DROP TABLE external_t5")
+ }
+
+ withTempDir { tempDir =>
+
+ // External parquet pointing to LOCATION
+
+ val parquetLocation = tempDir + "/external_parquet"
+ sql("SELECT 1, 2").write.parquet(parquetLocation)
+
+ sql(
+ s"""CREATE EXTERNAL TABLE external_parquet(c1 int, c2 int)
+ STORED AS PARQUET
+ LOCATION '$parquetLocation'
+ """)
+
+ val answer3 =
+ sql("SELECT input_file_name() as file FROM external_parquet").head().getString(0)
+ assert(answer3.contains("external_parquet"))
+
+ val count3 = sql("SELECT input_file_name() as file FROM external_parquet").distinct().count()
+ assert(count3 == 1)
+ sql("DROP TABLE external_parquet")
+ }
+
+ // Non-External parquet pointing to /tmp/...
+
+ sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " +
+ " STORED AS parquet " +
+ " AS SELECT 1, 2")
+
+ val answer4 =
+ sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0)
+ assert(answer4.contains("parquet_tmp"))
+
+ val count4 = sql("SELECT input_file_name() as file FROM parquet_tmp").distinct().count()
+ assert(count4 == 1)
+ sql("DROP TABLE parquet_tmp")
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
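
After this patch, a query over the delimited table created in the test above should see one distinct file name per underlying data file. A hypothetical spark-shell session, assuming sqlContext and the external_t5 table as set up in the test:

    // Hypothetical usage: input_file_name() now returns real paths for
    // Hive SerDe / delimited-text tables read through HadoopRDD.
    val files = sqlContext
      .sql("SELECT input_file_name() AS file FROM external_t5")
      .distinct()
      .collect()
    files.foreach(println)  // expect .../data1 and .../data2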