diff options
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 109 |
1 files changed, 75 insertions, 34 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 99a12d639a..1d3dc159da 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -924,26 +924,32 @@ class SQLTests(ReusedPySparkTestCase): def test_stream_save_options(self): df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.sqlCtx.streams.active: + cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.option('checkpointLocation', chk).queryName('this_query')\ + cq = df.write.option('checkpointLocation', chk).queryName('this_query') \ .format('parquet').option('path', out).startStream() - self.assertEqual(cq.name, 'this_query') - self.assertTrue(cq.isActive) - cq.processAllAvailable() - output_files = [] - for _, _, files in os.walk(out): - output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) - self.assertTrue(len(output_files) > 0) - self.assertTrue(len(os.listdir(chk)) > 0) - cq.stop() - shutil.rmtree(tmpPath) + try: + self.assertEqual(cq.name, 'this_query') + self.assertTrue(cq.isActive) + cq.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + finally: + cq.stop() + shutil.rmtree(tmpPath) def test_stream_save_options_overwrite(self): df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.sqlCtx.streams.active: + cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) @@ -954,21 +960,25 @@ class SQLTests(ReusedPySparkTestCase): cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - self.assertEqual(cq.name, 'this_query') - self.assertTrue(cq.isActive) - cq.processAllAvailable() - output_files = [] - for _, _, files in os.walk(out): - output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) - self.assertTrue(len(output_files) > 0) - self.assertTrue(len(os.listdir(chk)) > 0) - self.assertFalse(os.path.isdir(fake1)) # should not have been created - self.assertFalse(os.path.isdir(fake2)) # should not have been created - cq.stop() - shutil.rmtree(tmpPath) + try: + self.assertEqual(cq.name, 'this_query') + self.assertTrue(cq.isActive) + cq.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + self.assertFalse(os.path.isdir(fake1)) # should not have been created + self.assertFalse(os.path.isdir(fake2)) # should not have been created + finally: + cq.stop() + shutil.rmtree(tmpPath) def test_stream_await_termination(self): df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.sqlCtx.streams.active: + cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) @@ -976,19 +986,50 @@ class SQLTests(ReusedPySparkTestCase): chk = os.path.join(tmpPath, 'chk') cq = df.write.startStream(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - self.assertTrue(cq.isActive) try: - cq.awaitTermination("hello") - self.fail("Expected a value exception") - except ValueError: - pass - now = time.time() - res = cq.awaitTermination(2600) # test should take at least 2 seconds - duration = time.time() - now - self.assertTrue(duration >= 2) - self.assertFalse(res) - cq.stop() + self.assertTrue(cq.isActive) + try: + cq.awaitTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + # test should take at least 2 seconds + res = cq.awaitTermination(2.6) + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + finally: + cq.stop() + shutil.rmtree(tmpPath) + + def test_query_manager_await_termination(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.sqlCtx.streams.active: + cq.stop() + tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + cq = df.write.startStream(path=out, format='parquet', queryName='this_query', + checkpointLocation=chk) + try: + self.assertTrue(cq.isActive) + try: + self.sqlCtx.streams.awaitAnyTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + # test should take at least 2 seconds + res = self.sqlCtx.streams.awaitAnyTermination(2.6) + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + finally: + cq.stop() + shutil.rmtree(tmpPath) def test_help_command(self): # Regression test for SPARK-5464 |