aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py121
1 files changed, 121 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..6a1962d267
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,121 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+import sys
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
+from pyspark.java_gateway import SPARK_HOME
+
+
+class PySparkTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name , batchSize=2)
+
+ def tearDown(self):
+ self.sc.stop()
+ sys.path = self._old_sys_path
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc._jvm.System.clearProperty("spark.driver.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ PySparkTestCase.tearDown(self)
+ shutil.rmtree(self.checkpointDir.name)
+
+ def test_basic_checkpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual(self.checkpointDir.name,
+ os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+ def test_checkpoint_and_restore(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+
+class TestAddFile(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ self.assertRaises(Exception,
+ self.sc.parallelize(range(2)).map(func).first)
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEquals("Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
+
+class TestIO(PySparkTestCase):
+
+ def test_stdout_redirection(self):
+ import subprocess
+ def func(x):
+ subprocess.check_call('ls', shell=True)
+ self.sc.parallelize([1]).foreach(func)
+
+
+if __name__ == "__main__":
+ unittest.main()