aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-23 10:36:18 -0800
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-23 10:58:50 -0800
commitae2ed2947d43860c74a8d40767e289ca78073977 (patch)
treec72a585db233832ae53e52c4e015b8d52813643c /python/pyspark/tests.py
parent35168d9c89904f0dc0bb470c1799f5ca3b04221f (diff)
downloadspark-ae2ed2947d43860c74a8d40767e289ca78073977.tar.gz
spark-ae2ed2947d43860c74a8d40767e289ca78073977.tar.bz2
spark-ae2ed2947d43860c74a8d40767e289ca78073977.zip
Allow PySpark's SparkFiles to be used from driver
Fix minor documentation formatting issues.
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 4d70ee4f12..46ab34f063 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -4,22 +4,26 @@ individual modules.
"""
import os
import shutil
+import sys
from tempfile import NamedTemporaryFile
import time
import unittest
from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
class PySparkTestCase(unittest.TestCase):
def setUp(self):
+ self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
self.sc = SparkContext('local[4]', class_name , batchSize=2)
def tearDown(self):
self.sc.stop()
+ sys.path = self._old_sys_path
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.master.port")
@@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase):
res = self.sc.parallelize(range(2)).map(func).first()
self.assertEqual("Hello World!", res)
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEquals("Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
if __name__ == "__main__":
unittest.main()