Diffstat (limited to 'python/pyspark/sql/tests.py')
 python/pyspark/sql/tests.py | 93 +++++++++++++
 1 file changed, 93 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d4c221d712..1e864b4cd1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -879,6 +879,99 @@ class SQLTests(ReusedPySparkTestCase):
         shutil.rmtree(tmpPath)
+    def test_stream_trigger_takes_keyword_args(self):
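+        # trigger() takes its interval as a keyword argument only; a positional value should fail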
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        try:
+            df.write.trigger('5 seconds')
+            self.fail("Should have raised a TypeError")
+        except TypeError:
+            # expected: the trigger interval was passed positionally
+            pass
+
+    def test_stream_read_options(self):
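+        # options and schema set on the reader should carry over to the streaming DataFrame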
+        schema = StructType([StructField("data", StringType(), False)])
+        df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\
+            .schema(schema).stream()
+        self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+    def test_stream_read_options_overwrite(self):
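+        # keyword arguments passed to stream() should override the format, path and
+        # schema configured earlier on the reader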
+        bad_schema = StructType([StructField("test", IntegerType(), False)])
+        schema = StructType([StructField("data", StringType(), False)])
+        df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \
+            .schema(bad_schema).stream(path='python/test_support/sql/streaming',
+                                       schema=schema, format='text')
+        self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+    def test_stream_save_options(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        tmpPath = tempfile.mkdtemp()
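+        # mkdtemp() created the directory; remove it so only a unique, non-existent path remains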
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        cq = df.write.option('checkpointLocation', chk).queryName('this_query')\
+            .format('parquet').option('path', out).startStream()
+        self.assertEqual(cq.name, 'this_query')
+        self.assertTrue(cq.isActive)
+        cq.processAllAvailable()
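+        # collect the committed parquet output files, ignoring hidden metadata files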
+        output_files = []
+        for _, _, files in os.walk(out):
+            output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+        self.assertTrue(len(output_files) > 0)
+        self.assertTrue(len(os.listdir(chk)) > 0)
+        cq.stop()
+        shutil.rmtree(tmpPath)
+
+    def test_stream_save_options_overwrite(self):
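+        # keyword arguments passed to startStream() should override the options set
+        # earlier on the writer; the fake paths must never be touched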
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        fake1 = os.path.join(tmpPath, 'fake1')
+        fake2 = os.path.join(tmpPath, 'fake2')
+        cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \
+            .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query',
+                                                 checkpointLocation=chk)
+        self.assertEqual(cq.name, 'this_query')
+        self.assertTrue(cq.isActive)
+        cq.processAllAvailable()
+        output_files = []
+        for _, _, files in os.walk(out):
+            output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+        self.assertTrue(len(output_files) > 0)
+        self.assertTrue(len(os.listdir(chk)) > 0)
+        self.assertFalse(os.path.isdir(fake1))  # should not have been created
+        self.assertFalse(os.path.isdir(fake2))  # should not have been created
+        cq.stop()
+        shutil.rmtree(tmpPath)
+
+    def test_stream_await_termination(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
+                                  checkpointLocation=chk)
+        self.assertTrue(cq.isActive)
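+        # a non-numeric timeout should be rejected with a ValueError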
+        try:
+            cq.awaitTermination("hello")
+            self.fail("Expected a ValueError")
+        except ValueError:
+            pass
+        now = time.time()
+        # the query is still running, so this should block for the full timeout
+        # (at least 2 seconds) and then return False
+        res = cq.awaitTermination(2600)
+        duration = time.time() - now
+        self.assertTrue(duration >= 2)
+        self.assertFalse(res)
+        cq.stop()
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])