aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py38
1 files changed, 33 insertions, 5 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b0a403b580..4d70ee4f12 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -9,21 +9,32 @@ import time
import unittest
from pyspark.context import SparkContext
+from pyspark.java_gateway import SPARK_HOME
-class TestCheckpoint(unittest.TestCase):
+class PySparkTestCase(unittest.TestCase):
def setUp(self):
- self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
- self.checkpointDir = NamedTemporaryFile(delete=False)
- os.unlink(self.checkpointDir.name)
- self.sc.setCheckpointDir(self.checkpointDir.name)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name , batchSize=2)
def tearDown(self):
self.sc.stop()
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
+class TestAddFile(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ self.assertRaises(Exception,
+ self.sc.parallelize(range(2)).map(func).first)
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+
if __name__ == "__main__":
unittest.main()