diff options
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8250281da..73a5df65e0 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -435,6 +435,30 @@ class SQLTests(ReusedPySparkTestCase):
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)
 
+    def test_udf_with_input_file_name_for_hadooprdd(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+
+        def filename(path):
+            return path
+
+        sameText = udf(filename, StringType())
+
+        rdd = self.sc.textFile('python/test_support/sql/people.json')
+        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
+        row = df.select(sameText(df['file'])).first()
+        self.assertTrue(row[0].find("people.json") != -1)
+
+        rdd2 = self.sc.newAPIHadoopFile(
+            'python/test_support/sql/people.json',
+            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
+            'org.apache.hadoop.io.LongWritable',
+            'org.apache.hadoop.io.Text')
+
+        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
+        row2 = df2.select(sameText(df2['file'])).first()
+        self.assertTrue(row2[0].find("people.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)