From 78c8aaf849aadbb065730959e7c1b70bb58d69c9 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 28 Apr 2016 15:22:28 -0700
Subject: [SPARK-14555] Second cut of Python API for Structured Streaming

## What changes were proposed in this pull request?

This PR adds Python APIs for:
- `ContinuousQueryManager`
- `ContinuousQueryException`

`ContinuousQueryException` is a very basic wrapper; it doesn't provide the functionality that
the Scala side provides, but it follows the same pattern as `AnalysisException`.

For `ContinuousQueryManager`, all APIs are provided except for registering listeners.

This PR also attempts to fix test flakiness by stopping all active streams just before tests.

## How was this patch tested?

Python Doc tests and unit tests

Author: Burak Yavuz

Closes #12673 from brkyvz/pyspark-cqm.
---
 python/pyspark/sql/context.py    |   9 +++
 python/pyspark/sql/readwriter.py |   2 +-
 python/pyspark/sql/streaming.py  | 135 ++++++++++++++++++++++++++++++++++----
 python/pyspark/sql/tests.py      | 109 +++++++++++++++++++++----------
 python/pyspark/sql/utils.py      |   8 +++
 5 files changed, 217 insertions(+), 46 deletions(-)

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 48ffb59668..a3ea192b28 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -423,6 +423,15 @@ class SQLContext(object):
         """
         return DataFrameReader(self)
 
+    @property
+    @since(2.0)
+    def streams(self):
+        """Returns a :class:`ContinuousQueryManager` that allows managing all the
+        :class:`ContinuousQuery` ContinuousQueries active on `this` context.
+        """
+        from pyspark.sql.streaming import ContinuousQueryManager
+        return ContinuousQueryManager(self._ssql_ctx.streams())
+
 
 # TODO(andrew): deprecate this
 class HiveContext(SQLContext):

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 784609e4c5..ed9e716ab7 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -365,7 +365,7 @@ class DataFrameWriter(object):
 
     def _cq(self, jcq):
         from pyspark.sql.streaming import ContinuousQuery
-        return ContinuousQuery(jcq, self._sqlContext)
+        return ContinuousQuery(jcq)
 
     @since(1.4)
     def mode(self, saveMode):

diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 549561669f..bf03fdca91 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -18,6 +18,7 @@
 from abc import ABCMeta, abstractmethod
 
 from pyspark import since
+from pyspark.rdd import ignore_unicode_prefix
 
 __all__ = ["ContinuousQuery"]
 
@@ -32,9 +33,8 @@ class ContinuousQuery(object):
     .. versionadded:: 2.0
     """
 
-    def __init__(self, jcq, sqlContext):
+    def __init__(self, jcq):
         self._jcq = jcq
-        self._sqlContext = sqlContext
 
     @property
     @since(2.0)
@@ -51,22 +51,22 @@ class ContinuousQuery(object):
         return self._jcq.isActive()
 
     @since(2.0)
-    def awaitTermination(self, timeoutMs=None):
+    def awaitTermination(self, timeout=None):
         """Waits for the termination of `this` query, either by :func:`query.stop()` or by an
         exception. If the query has terminated with an exception, then the exception will be
         thrown.
-        If `timeoutMs` is set, it returns whether the query has terminated or not within the
-        `timeoutMs` milliseconds.
+        If `timeout` is set, it returns whether the query has terminated or not within the
+        `timeout` seconds.
 
         If the query has terminated, then all subsequent calls to this method will either return
         immediately (if the query was terminated by :func:`stop()`), or throw the exception
         immediately (if the query has terminated with exception).
 
-        throws ContinuousQueryException, if `this` query has terminated with an exception
+        throws :class:`ContinuousQueryException`, if `this` query has terminated with an exception
         """
-        if timeoutMs is not None:
-            if type(timeoutMs) != int or timeoutMs < 0:
-                raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs)
-            return self._jcq.awaitTermination(timeoutMs)
+        if timeout is not None:
+            if not isinstance(timeout, (int, float)) or timeout < 0:
+                raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
+            return self._jcq.awaitTermination(int(timeout * 1000))
         else:
             return self._jcq.awaitTermination()
 
@@ -87,6 +87,86 @@ class ContinuousQuery(object):
         self._jcq.stop()
 
 
+class ContinuousQueryManager(object):
+    """A class to manage all the :class:`ContinuousQuery` ContinuousQueries active
+    on a :class:`SQLContext`.
+
+    .. note:: Experimental
+
+    .. versionadded:: 2.0
+    """
+
+    def __init__(self, jcqm):
+        self._jcqm = jcqm
+
+    @property
+    @ignore_unicode_prefix
+    @since(2.0)
+    def active(self):
+        """Returns a list of active queries associated with this SQLContext
+
+        >>> cq = df.write.format('memory').queryName('this_query').startStream()
+        >>> cqm = sqlContext.streams
+        >>> # get the list of active continuous queries
+        >>> [q.name for q in cqm.active]
+        [u'this_query']
+        >>> cq.stop()
+        """
+        return [ContinuousQuery(jcq) for jcq in self._jcqm.active()]
+
+    @since(2.0)
+    def get(self, name):
+        """Returns an active query from this SQLContext or throws exception if an active query
+        with this name doesn't exist.
+
+        >>> df.write.format('memory').queryName('this_query').startStream()
+        >>> cq = sqlContext.streams.get('this_query')
+        >>> cq.isActive
+        True
+        >>> cq.stop()
+        """
+        if type(name) != str or len(name.strip()) == 0:
+            raise ValueError("The name for the query must be a non-empty string. Got: %s" % name)
+        return ContinuousQuery(self._jcqm.get(name))
+
+    @since(2.0)
+    def awaitAnyTermination(self, timeout=None):
+        """Wait until any of the queries on the associated SQLContext has terminated since the
+        creation of the context, or since :func:`resetTerminated()` was called. If any query was
+        terminated with an exception, then the exception will be thrown.
+        If `timeout` is set, it returns whether the query has terminated or not within the
+        `timeout` seconds.
+
+        If a query has terminated, then subsequent calls to :func:`awaitAnyTermination()` will
+        either return immediately (if the query was terminated by :func:`query.stop()`),
+        or throw the exception immediately (if the query was terminated with exception). Use
+        :func:`resetTerminated()` to clear past terminations and wait for new terminations.
+
+        In the case where multiple queries have terminated since :func:`resetTermination()`
+        was called, if any query has terminated with exception, then :func:`awaitAnyTermination()`
+        will throw any of the exception. For correctly documenting exceptions across multiple
+        queries, users need to stop all of them after any of them terminates with exception, and
+        then check the `query.exception()` for each query.
+
+        throws :class:`ContinuousQueryException`, if `this` query has terminated with an exception
+        """
+        if timeout is not None:
+            if not isinstance(timeout, (int, float)) or timeout < 0:
+                raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
+            return self._jcqm.awaitAnyTermination(int(timeout * 1000))
+        else:
+            return self._jcqm.awaitAnyTermination()
+
+    @since(2.0)
+    def resetTerminated(self):
+        """Forget about past terminated queries so that :func:`awaitAnyTermination()` can be used
+        again to wait for new terminations.
+
+        >>> sqlContext.streams.resetTerminated()
+        """
+        self._jcqm.resetTerminated()
+
+
 class Trigger(object):
     """Used to indicate how often results should be produced by a :class:`ContinuousQuery`.
 
@@ -116,9 +196,42 @@ class ProcessingTime(Trigger):
     """
 
     def __init__(self, interval):
-        if interval is None or type(interval) != str or len(interval.strip()) == 0:
+        if type(interval) != str or len(interval.strip()) == 0:
             raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.")
         self.interval = interval
 
     def _to_java_trigger(self, sqlContext):
         return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval)
+
+
+def _test():
+    import doctest
+    import os
+    import tempfile
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row, SQLContext, HiveContext
+    import pyspark.sql.readwriter
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.readwriter.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+
+    globs['tempfile'] = tempfile
+    globs['os'] = os
+    globs['sc'] = sc
+    globs['sqlContext'] = SQLContext(sc)
+    globs['hiveContext'] = HiveContext(sc)
+    globs['df'] = \
+        globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.readwriter, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 99a12d639a..1d3dc159da 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -924,26 +924,32 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_stream_save_options(self):
         df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.sqlCtx.streams.active:
+            cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         self.assertTrue(df.isStreaming)
         out = os.path.join(tmpPath, 'out')
         chk = os.path.join(tmpPath, 'chk')
-        cq = df.write.option('checkpointLocation', chk).queryName('this_query')\
+        cq = df.write.option('checkpointLocation', chk).queryName('this_query') \
             .format('parquet').option('path', out).startStream()
-        self.assertEqual(cq.name, 'this_query')
-        self.assertTrue(cq.isActive)
-        cq.processAllAvailable()
-        output_files = []
-        for _, _, files in os.walk(out):
-            output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
-        self.assertTrue(len(output_files) > 0)
-        self.assertTrue(len(os.listdir(chk)) > 0)
-        cq.stop()
-        shutil.rmtree(tmpPath)
+        try:
+            self.assertEqual(cq.name, 'this_query')
+            self.assertTrue(cq.isActive)
+            cq.processAllAvailable()
+            output_files = []
+            for _, _, files in os.walk(out):
+                output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+            self.assertTrue(len(output_files) > 0)
+            self.assertTrue(len(os.listdir(chk)) > 0)
+        finally:
+            cq.stop()
+            shutil.rmtree(tmpPath)
 
     def test_stream_save_options_overwrite(self):
         df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.sqlCtx.streams.active:
+            cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         self.assertTrue(df.isStreaming)
@@ -954,21 +960,25 @@ class SQLTests(ReusedPySparkTestCase):
         cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \
             .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query',
                                                  checkpointLocation=chk)
-        self.assertEqual(cq.name, 'this_query')
-        self.assertTrue(cq.isActive)
-        cq.processAllAvailable()
-        output_files = []
-        for _, _, files in os.walk(out):
-            output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
-        self.assertTrue(len(output_files) > 0)
-        self.assertTrue(len(os.listdir(chk)) > 0)
-        self.assertFalse(os.path.isdir(fake1))  # should not have been created
-        self.assertFalse(os.path.isdir(fake2))  # should not have been created
-        cq.stop()
-        shutil.rmtree(tmpPath)
+        try:
+            self.assertEqual(cq.name, 'this_query')
+            self.assertTrue(cq.isActive)
+            cq.processAllAvailable()
+            output_files = []
+            for _, _, files in os.walk(out):
+                output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+            self.assertTrue(len(output_files) > 0)
+            self.assertTrue(len(os.listdir(chk)) > 0)
+            self.assertFalse(os.path.isdir(fake1))  # should not have been created
+            self.assertFalse(os.path.isdir(fake2))  # should not have been created
+        finally:
+            cq.stop()
+            shutil.rmtree(tmpPath)
 
     def test_stream_await_termination(self):
         df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.sqlCtx.streams.active:
+            cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         self.assertTrue(df.isStreaming)
@@ -976,19 +986,50 @@ class SQLTests(ReusedPySparkTestCase):
         chk = os.path.join(tmpPath, 'chk')
         cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
                                   checkpointLocation=chk)
-        self.assertTrue(cq.isActive)
         try:
-            cq.awaitTermination("hello")
-            self.fail("Expected a value exception")
-        except ValueError:
-            pass
-        now = time.time()
-        res = cq.awaitTermination(2600)  # test should take at least 2 seconds
-        duration = time.time() - now
-        self.assertTrue(duration >= 2)
-        self.assertFalse(res)
-        cq.stop()
+            self.assertTrue(cq.isActive)
+            try:
+                cq.awaitTermination("hello")
+                self.fail("Expected a value exception")
+            except ValueError:
+                pass
+            now = time.time()
+            # test should take at least 2 seconds
+            res = cq.awaitTermination(2.6)
+            duration = time.time() - now
+            self.assertTrue(duration >= 2)
+            self.assertFalse(res)
+        finally:
+            cq.stop()
+            shutil.rmtree(tmpPath)
+
+    def test_query_manager_await_termination(self):
+        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.sqlCtx.streams.active:
+            cq.stop()
+        tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
+                                  checkpointLocation=chk)
+        try:
+            self.assertTrue(cq.isActive)
+            try:
+                self.sqlCtx.streams.awaitAnyTermination("hello")
+                self.fail("Expected a value exception")
+            except ValueError:
+                pass
+            now = time.time()
+            # test should take at least 2 seconds
+            res = self.sqlCtx.streams.awaitAnyTermination(2.6)
+            duration = time.time() - now
+            self.assertTrue(duration >= 2)
+            self.assertFalse(res)
+        finally:
+            cq.stop()
+            shutil.rmtree(tmpPath)
 
     def test_help_command(self):
         # Regression test for SPARK-5464

diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7ea0e0d5c9..cb172d21f3 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -45,6 +45,12 @@ class IllegalArgumentException(CapturedException):
     """
 
 
+class ContinuousQueryException(CapturedException):
+    """
+    Exception that stopped a :class:`ContinuousQuery`.
+    """
+
+
 def capture_sql_exception(f):
     def deco(*a, **kw):
         try:
@@ -57,6 +63,8 @@ def capture_sql_exception(f):
                 raise AnalysisException(s.split(': ', 1)[1], stackTrace)
             if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
                 raise ParseException(s.split(': ', 1)[1], stackTrace)
+            if s.startswith('org.apache.spark.sql.ContinuousQueryException: '):
+                raise ContinuousQueryException(s.split(': ', 1)[1], stackTrace)
             if s.startswith('java.lang.IllegalArgumentException: '):
                 raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
             raise
-- 
cgit v1.2.3
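The patch above contains only library code and tests, so here is a minimal, self-contained usage sketch of the Python surface it adds (`SQLContext.streams`, `ContinuousQueryManager`, and `ContinuousQueryException`). This is illustrative only and is not part of the commit; the input path reuses the repository's streaming test fixture directory, and the query name is a placeholder.

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.utils import ContinuousQueryException

sc = SparkContext('local[4]', 'StreamingExample')
sqlContext = SQLContext(sc)

# A streaming DataFrame read from a directory of text files (fixture path from the repo).
df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')

# startStream() returns a ContinuousQuery handle for the running query.
cq = df.write.format('memory').queryName('this_query').startStream()

# The new SQLContext.streams property exposes a ContinuousQueryManager.
cqm = sqlContext.streams
print([q.name for q in cqm.active])   # e.g. ['this_query']
same_cq = cqm.get('this_query')       # look up an active query by name
print(same_cq.isActive)               # True while the query is running

try:
    # Timeouts are expressed in (possibly fractional) seconds; returns a bool.
    finished = cqm.awaitAnyTermination(2.5)
    print(finished)
except ContinuousQueryException as e:
    # Raised when a query terminated with an error, mirroring the Scala-side exception.
    print("query failed: %s" % e)

cqm.resetTerminated()   # forget past terminations before waiting again
cq.stop()
sc.stop()
```

Note that `awaitTermination` and `awaitAnyTermination` take their timeout in seconds (int or float) and convert it to milliseconds with `int(timeout * 1000)` before calling into the JVM, which is why the updated tests pass `2.6` where the old code passed `2600`.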