aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8250281da..73a5df65e0 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -435,6 +435,30 @@ class SQLTests(ReusedPySparkTestCase):
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)
+ def test_udf_with_input_file_name_for_hadooprdd(self):
+ from pyspark.sql.functions import udf, input_file_name
+ from pyspark.sql.types import StringType
+
+ def filename(path):
+ return path
+
+ sameText = udf(filename, StringType())
+
+ rdd = self.sc.textFile('python/test_support/sql/people.json')
+ df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
+ row = df.select(sameText(df['file'])).first()
+ self.assertTrue(row[0].find("people.json") != -1)
+
+ rdd2 = self.sc.newAPIHadoopFile(
+ 'python/test_support/sql/people.json',
+ 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
+ 'org.apache.hadoop.io.LongWritable',
+ 'org.apache.hadoop.io.Text')
+
+ df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
+ row2 = df2.select(sameText(df2['file'])).first()
+ self.assertTrue(row2[0].find("people.json") != -1)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)