diff options
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 93 |
1 files changed, 93 insertions, 0 deletions
Reviewer summary (free text ahead of the "diff --git" header; not part of the hunk).

This hunk adds six streaming tests to SQLTests in python/pyspark/sql/tests.py:
  - test_stream_trigger_takes_keyword_args: df.write.trigger('5 seconds') with a
    positional argument is expected to raise TypeError.
  - test_stream_read_options: format/path/schema set on the reader are used by
    stream(); asserts df.isStreaming and the schema string "struct<data:string>".
  - test_stream_read_options_overwrite: path/schema/format passed as keyword
    arguments to stream() are expected to win over the values set earlier on the
    reader (csv format, "fake" path, bad_schema).
  - test_stream_save_options: checkpointLocation/queryName/format/path set on the
    writer are used by startStream(); asserts the query name, that it is active,
    and that parquet files and checkpoint entries appear on disk.
  - test_stream_save_options_overwrite: keyword arguments to startStream() are
    expected to win over writer options; asserts the fake1/fake2 directories are
    never created.
  - test_stream_await_termination: awaitTermination("hello") is expected to raise
    ValueError; awaitTermination(2600) is expected to return a falsy value after
    at least 2 seconds.

NOTE(review): in the three tests that start a query, cq.stop() and
shutil.rmtree(tmpPath) are the last statements and are not in a finally block,
so they are skipped when an earlier assertion fails - consider try/finally.
NOTE(review): 2600 is presumably a timeout in milliseconds (the inline comment
expects >= 2 seconds) - confirm against awaitTermination's signature.
NOTE(review): the try / self.fail(...) / except pattern could be expressed with
assertRaises.
NOTE(review): indentation and the two blank context lines below were
reconstructed from the "+" markers and the hunk counts (6 old, 99 new lines).

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d4c221d712..1e864b4cd1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -879,6 +879,99 @@ class SQLTests(ReusedPySparkTestCase):
 
         shutil.rmtree(tmpPath)
 
+    def test_stream_trigger_takes_keyword_args(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        try:
+            df.write.trigger('5 seconds')
+            self.fail("Should have thrown an exception")
+        except TypeError:
+            # should throw error
+            pass
+
+    def test_stream_read_options(self):
+        schema = StructType([StructField("data", StringType(), False)])
+        df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\
+            .schema(schema).stream()
+        self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+    def test_stream_read_options_overwrite(self):
+        bad_schema = StructType([StructField("test", IntegerType(), False)])
+        schema = StructType([StructField("data", StringType(), False)])
+        df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \
+            .schema(bad_schema).stream(path='python/test_support/sql/streaming',
+                                       schema=schema, format='text')
+        self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+    def test_stream_save_options(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        cq = df.write.option('checkpointLocation', chk).queryName('this_query')\
+            .format('parquet').option('path', out).startStream()
+        self.assertEqual(cq.name, 'this_query')
+        self.assertTrue(cq.isActive)
+        cq.processAllAvailable()
+        output_files = []
+        for _, _, files in os.walk(out):
+            output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+        self.assertTrue(len(output_files) > 0)
+        self.assertTrue(len(os.listdir(chk)) > 0)
+        cq.stop()
+        shutil.rmtree(tmpPath)
+
+    def test_stream_save_options_overwrite(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        fake1 = os.path.join(tmpPath, 'fake1')
+        fake2 = os.path.join(tmpPath, 'fake2')
+        cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \
+            .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query',
+                                                 checkpointLocation=chk)
+        self.assertEqual(cq.name, 'this_query')
+        self.assertTrue(cq.isActive)
+        cq.processAllAvailable()
+        output_files = []
+        for _, _, files in os.walk(out):
+            output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+        self.assertTrue(len(output_files) > 0)
+        self.assertTrue(len(os.listdir(chk)) > 0)
+        self.assertFalse(os.path.isdir(fake1))  # should not have been created
+        self.assertFalse(os.path.isdir(fake2))  # should not have been created
+        cq.stop()
+        shutil.rmtree(tmpPath)
+
+    def test_stream_await_termination(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
+                                  checkpointLocation=chk)
+        self.assertTrue(cq.isActive)
+        try:
+            cq.awaitTermination("hello")
+            self.fail("Expected a value exception")
+        except ValueError:
+            pass
+        now = time.time()
+        res = cq.awaitTermination(2600)  # test should take at least 2 seconds
+        duration = time.time() - now
+        self.assertTrue(duration >= 2)
+        self.assertFalse(res)
+        cq.stop()
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])