author     Yanbo Liang <ybliang8@gmail.com>  2016-09-21 01:37:03 -0700
committer  Yanbo Liang <ybliang8@gmail.com>  2016-09-21 01:37:03 -0700
commit     d3b88697638dcf32854fe21a6c53dfb3782773b9 (patch)
tree       8d5e243a53fb22c3e092556cc87ccd1c549329d6
parent     61876a42793bde0da90f54b44255148ed54b7f61 (diff)
[SPARK-17585][PYSPARK][CORE] PySpark SparkContext.addFile supports adding files recursively
## What changes were proposed in this pull request?

In some cases users would like to add a whole directory as a dependency. In Scala they can call ```SparkContext.addFile``` with the argument ```recursive=true``` to recursively add all files under that directory, but Python users can only add individual files, not directories. This patch adds the same support to PySpark.

## How was this patch tested?

Unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15140 from yanboliang/spark-17585.
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala              13
-rw-r--r--  python/pyspark/context.py                                                           7
-rw-r--r--  python/pyspark/tests.py                                                            20
-rwxr-xr-x  python/test_support/hello/hello.txt (renamed from python/test_support/hello.txt)   0
-rw-r--r--  python/test_support/hello/sub_hello/sub_hello.txt                                   1
5 files changed, 34 insertions, 7 deletions
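With this patch, the Python API mirrors the existing Scala behaviour. Below is a minimal, hedged sketch of the new driver-side call; the directory layout mirrors the `python/test_support/hello` fixtures added in this patch and is assumed to exist relative to the working directory.

```python
import os
from pyspark import SparkContext, SparkFiles

sc = SparkContext(appName="addfile-recursive-sketch")

# Add a whole directory; with the default recursive=False, addFile only accepts single files.
sc.addFile("python/test_support/hello", recursive=True)

# SparkFiles.get resolves where the directory was downloaded on this node;
# files keep their relative layout underneath it.
root = SparkFiles.get("hello")
with open(os.path.join(root, "hello.txt")) as f:
    print(f.readline())            # Hello World!
with open(os.path.join(root, "sub_hello", "sub_hello.txt")) as f:
    print(f.readline())            # Sub Hello World!

sc.stop()
```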
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 131f36f547..4e50c2686d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -670,6 +670,19 @@ class JavaSparkContext(val sc: SparkContext)
}
/**
+ * Add a file to be downloaded with this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(fileName)` to find its download location.
+ *
+ * A directory can be given if the recursive option is set to true. Currently directories are only
+ * supported for Hadoop-supported filesystems.
+ */
+ def addFile(path: String, recursive: Boolean): Unit = {
+ sc.addFile(path, recursive)
+ }
+
+ /**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5c32f8ea1d..7a7f59cb50 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -767,7 +767,7 @@ class SparkContext(object):
SparkContext._next_accum_id += 1
return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
- def addFile(self, path):
+ def addFile(self, path, recursive=False):
"""
Add a file to be downloaded with this Spark job on every node.
The C{path} passed can be either a local file, a file in HDFS
@@ -778,6 +778,9 @@ class SparkContext(object):
L{SparkFiles.get(fileName)<pyspark.files.SparkFiles.get>} with the
filename to find its download location.
+ A directory can be given if the recursive option is set to True.
+ Currently directories are only supported for Hadoop-supported filesystems.
+
>>> from pyspark import SparkFiles
>>> path = os.path.join(tempdir, "test.txt")
>>> with open(path, "w") as testFile:
@@ -790,7 +793,7 @@ class SparkContext(object):
>>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
[100, 200, 300, 400]
"""
- self._jsc.sc().addFile(path)
+ self._jsc.sc().addFile(path, recursive)
def addPyFile(self, path):
"""
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0a029b6e74..b0756911bf 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -409,13 +409,23 @@ class AddFileTests(PySparkTestCase):
self.assertEqual("Hello World!", res)
def test_add_file_locally(self):
- path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
self.sc.addFile(path)
download_path = SparkFiles.get("hello.txt")
self.assertNotEqual(path, download_path)
with open(download_path) as test_file:
self.assertEqual("Hello World!\n", test_file.readline())
+ def test_add_file_recursively_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello")
+ self.sc.addFile(path, True)
+ download_path = SparkFiles.get("hello")
+ self.assertNotEqual(path, download_path)
+ with open(download_path + "/hello.txt") as test_file:
+ self.assertEqual("Hello World!\n", test_file.readline())
+ with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
+ self.assertEqual("Sub Hello World!\n", test_file.readline())
+
def test_add_py_file_locally(self):
# To ensure that we're actually testing addPyFile's effects, check that
# this fails due to `userlibrary` not being on the Python path:
@@ -514,7 +524,7 @@ class RDDTests(ReusedPySparkTestCase):
def test_cartesian_on_textfile(self):
# Regression test for
- path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
a = self.sc.textFile(path)
result = a.cartesian(a).collect()
(x, y) = result[0]
@@ -751,7 +761,7 @@ class RDDTests(ReusedPySparkTestCase):
b = b._reserialize(MarshalSerializer())
self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
# regression test for SPARK-4841
- path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
t = self.sc.textFile(path)
cnt = t.count()
self.assertEqual(cnt, t.zip(t).count())
@@ -1214,7 +1224,7 @@ class InputFormatTests(ReusedPySparkTestCase):
ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
self.assertEqual(ints, ei)
- hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
oldconf = {"mapred.input.dir": hellopath}
hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
"org.apache.hadoop.io.LongWritable",
@@ -1233,7 +1243,7 @@ class InputFormatTests(ReusedPySparkTestCase):
ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
self.assertEqual(ints, ei)
- hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
newconf = {"mapred.input.dir": hellopath}
hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
"org.apache.hadoop.io.LongWritable",
diff --git a/python/test_support/hello.txt b/python/test_support/hello/hello.txt
index 980a0d5f19..980a0d5f19 100755
--- a/python/test_support/hello.txt
+++ b/python/test_support/hello/hello.txt
diff --git a/python/test_support/hello/sub_hello/sub_hello.txt b/python/test_support/hello/sub_hello/sub_hello.txt
new file mode 100644
index 0000000000..ce2d435b8c
--- /dev/null
+++ b/python/test_support/hello/sub_hello/sub_hello.txt
@@ -0,0 +1 @@
+Sub Hello World!