aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2016-04-28 15:22:28 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2016-04-28 15:22:28 -0700
commit78c8aaf849aadbb065730959e7c1b70bb58d69c9 (patch)
tree8795a2ad1c070ec852809d5f6e81c5bbd1c3afd8
parentd584a2b8ac57eff3bf230c760e5bda205c6ea747 (diff)
downloadspark-78c8aaf849aadbb065730959e7c1b70bb58d69c9.tar.gz
spark-78c8aaf849aadbb065730959e7c1b70bb58d69c9.tar.bz2
spark-78c8aaf849aadbb065730959e7c1b70bb58d69c9.zip
[SPARK-14555] Second cut of Python API for Structured Streaming
## What changes were proposed in this pull request? This PR adds Python APIs for: - `ContinuousQueryManager` - `ContinuousQueryException` The `ContinuousQueryException` is a very basic wrapper, it doesn't provide the functionality that the Scala side provides, but it follows the same pattern for `AnalysisException`. For `ContinuousQueryManager`, all APIs are provided except for registering listeners. This PR also attempts to fix test flakiness by stopping all active streams just before tests. ## How was this patch tested? Python Doc tests and unit tests Author: Burak Yavuz <brkyvz@gmail.com> Closes #12673 from brkyvz/pyspark-cqm.
-rw-r--r--python/pyspark/sql/context.py9
-rw-r--r--python/pyspark/sql/readwriter.py2
-rw-r--r--python/pyspark/sql/streaming.py135
-rw-r--r--python/pyspark/sql/tests.py109
-rw-r--r--python/pyspark/sql/utils.py8
5 files changed, 217 insertions, 46 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 48ffb59668..a3ea192b28 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -423,6 +423,15 @@ class SQLContext(object):
"""
return DataFrameReader(self)
+ @property
+ @since(2.0)
+ def streams(self):
+ """Returns a :class:`ContinuousQueryManager` that allows managing all the
+ :class:`ContinuousQuery` ContinuousQueries active on `this` context.
+ """
+ from pyspark.sql.streaming import ContinuousQueryManager
+ return ContinuousQueryManager(self._ssql_ctx.streams())
+
# TODO(andrew): deprecate this
class HiveContext(SQLContext):
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 784609e4c5..ed9e716ab7 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -365,7 +365,7 @@ class DataFrameWriter(object):
def _cq(self, jcq):
from pyspark.sql.streaming import ContinuousQuery
- return ContinuousQuery(jcq, self._sqlContext)
+ return ContinuousQuery(jcq)
@since(1.4)
def mode(self, saveMode):
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 549561669f..bf03fdca91 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -18,6 +18,7 @@
from abc import ABCMeta, abstractmethod
from pyspark import since
+from pyspark.rdd import ignore_unicode_prefix
__all__ = ["ContinuousQuery"]
@@ -32,9 +33,8 @@ class ContinuousQuery(object):
.. versionadded:: 2.0
"""
- def __init__(self, jcq, sqlContext):
+ def __init__(self, jcq):
self._jcq = jcq
- self._sqlContext = sqlContext
@property
@since(2.0)
@@ -51,22 +51,22 @@ class ContinuousQuery(object):
return self._jcq.isActive()
@since(2.0)
- def awaitTermination(self, timeoutMs=None):
+ def awaitTermination(self, timeout=None):
"""Waits for the termination of `this` query, either by :func:`query.stop()` or by an
exception. If the query has terminated with an exception, then the exception will be thrown.
- If `timeoutMs` is set, it returns whether the query has terminated or not within the
- `timeoutMs` milliseconds.
+ If `timeout` is set, it returns whether the query has terminated or not within the
+ `timeout` seconds.
If the query has terminated, then all subsequent calls to this method will either return
immediately (if the query was terminated by :func:`stop()`), or throw the exception
immediately (if the query has terminated with exception).
- throws ContinuousQueryException, if `this` query has terminated with an exception
+ throws :class:`ContinuousQueryException`, if `this` query has terminated with an exception
"""
- if timeoutMs is not None:
- if type(timeoutMs) != int or timeoutMs < 0:
- raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs)
- return self._jcq.awaitTermination(timeoutMs)
+ if timeout is not None:
+ if not isinstance(timeout, (int, float)) or timeout < 0:
+ raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
+ return self._jcq.awaitTermination(int(timeout * 1000))
else:
return self._jcq.awaitTermination()
@@ -87,6 +87,86 @@ class ContinuousQuery(object):
self._jcq.stop()
+class ContinuousQueryManager(object):
+ """A class to manage all the :class:`ContinuousQuery` ContinuousQueries active
+ on a :class:`SQLContext`.
+
+ .. note:: Experimental
+
+ .. versionadded:: 2.0
+ """
+
+ def __init__(self, jcqm):
+ self._jcqm = jcqm
+
+ @property
+ @ignore_unicode_prefix
+ @since(2.0)
+ def active(self):
+ """Returns a list of active queries associated with this SQLContext
+
+ >>> cq = df.write.format('memory').queryName('this_query').startStream()
+ >>> cqm = sqlContext.streams
+ >>> # get the list of active continuous queries
+ >>> [q.name for q in cqm.active]
+ [u'this_query']
+ >>> cq.stop()
+ """
+ return [ContinuousQuery(jcq) for jcq in self._jcqm.active()]
+
+ @since(2.0)
+ def get(self, name):
+ """Returns an active query from this SQLContext or throws exception if an active query
+ with this name doesn't exist.
+
+ >>> df.write.format('memory').queryName('this_query').startStream()
+ >>> cq = sqlContext.streams.get('this_query')
+ >>> cq.isActive
+ True
+ >>> cq.stop()
+ """
+ if type(name) != str or len(name.strip()) == 0:
+ raise ValueError("The name for the query must be a non-empty string. Got: %s" % name)
+ return ContinuousQuery(self._jcqm.get(name))
+
+ @since(2.0)
+ def awaitAnyTermination(self, timeout=None):
+ """Wait until any of the queries on the associated SQLContext has terminated since the
+ creation of the context, or since :func:`resetTerminated()` was called. If any query was
+ terminated with an exception, then the exception will be thrown.
+ If `timeout` is set, it returns whether the query has terminated or not within the
+ `timeout` seconds.
+
+ If a query has terminated, then subsequent calls to :func:`awaitAnyTermination()` will
+ either return immediately (if the query was terminated by :func:`query.stop()`),
+ or throw the exception immediately (if the query was terminated with exception). Use
+ :func:`resetTerminated()` to clear past terminations and wait for new terminations.
+
+ In the case where multiple queries have terminated since :func:`resetTermination()`
+ was called, if any query has terminated with exception, then :func:`awaitAnyTermination()`
+ will throw any of the exception. For correctly documenting exceptions across multiple
+ queries, users need to stop all of them after any of them terminates with exception, and
+ then check the `query.exception()` for each query.
+
+ throws :class:`ContinuousQueryException`, if `this` query has terminated with an exception
+ """
+ if timeout is not None:
+ if not isinstance(timeout, (int, float)) or timeout < 0:
+ raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
+ return self._jcqm.awaitAnyTermination(int(timeout * 1000))
+ else:
+ return self._jcqm.awaitAnyTermination()
+
+ @since(2.0)
+ def resetTerminated(self):
+ """Forget about past terminated queries so that :func:`awaitAnyTermination()` can be used
+ again to wait for new terminations.
+
+ >>> sqlContext.streams.resetTerminated()
+ """
+ self._jcqm.resetTerminated()
+
+
class Trigger(object):
"""Used to indicate how often results should be produced by a :class:`ContinuousQuery`.
@@ -116,9 +196,42 @@ class ProcessingTime(Trigger):
"""
def __init__(self, interval):
- if interval is None or type(interval) != str or len(interval.strip()) == 0:
+ if type(interval) != str or len(interval.strip()) == 0:
raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.")
self.interval = interval
def _to_java_trigger(self, sqlContext):
return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval)
+
+
+def _test():
+ import doctest
+ import os
+ import tempfile
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext, HiveContext
+ import pyspark.sql.readwriter
+
+ os.chdir(os.environ["SPARK_HOME"])
+
+ globs = pyspark.sql.readwriter.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+
+ globs['tempfile'] = tempfile
+ globs['os'] = os
+ globs['sc'] = sc
+ globs['sqlContext'] = SQLContext(sc)
+ globs['hiveContext'] = HiveContext(sc)
+ globs['df'] = \
+ globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.readwriter, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 99a12d639a..1d3dc159da 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -924,26 +924,32 @@ class SQLTests(ReusedPySparkTestCase):
def test_stream_save_options(self):
df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.sqlCtx.streams.active:
+ cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
out = os.path.join(tmpPath, 'out')
chk = os.path.join(tmpPath, 'chk')
- cq = df.write.option('checkpointLocation', chk).queryName('this_query')\
+ cq = df.write.option('checkpointLocation', chk).queryName('this_query') \
.format('parquet').option('path', out).startStream()
- self.assertEqual(cq.name, 'this_query')
- self.assertTrue(cq.isActive)
- cq.processAllAvailable()
- output_files = []
- for _, _, files in os.walk(out):
- output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
- self.assertTrue(len(output_files) > 0)
- self.assertTrue(len(os.listdir(chk)) > 0)
- cq.stop()
- shutil.rmtree(tmpPath)
+ try:
+ self.assertEqual(cq.name, 'this_query')
+ self.assertTrue(cq.isActive)
+ cq.processAllAvailable()
+ output_files = []
+ for _, _, files in os.walk(out):
+ output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+ self.assertTrue(len(output_files) > 0)
+ self.assertTrue(len(os.listdir(chk)) > 0)
+ finally:
+ cq.stop()
+ shutil.rmtree(tmpPath)
def test_stream_save_options_overwrite(self):
df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.sqlCtx.streams.active:
+ cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
@@ -954,21 +960,25 @@ class SQLTests(ReusedPySparkTestCase):
cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \
.queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query',
checkpointLocation=chk)
- self.assertEqual(cq.name, 'this_query')
- self.assertTrue(cq.isActive)
- cq.processAllAvailable()
- output_files = []
- for _, _, files in os.walk(out):
- output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
- self.assertTrue(len(output_files) > 0)
- self.assertTrue(len(os.listdir(chk)) > 0)
- self.assertFalse(os.path.isdir(fake1)) # should not have been created
- self.assertFalse(os.path.isdir(fake2)) # should not have been created
- cq.stop()
- shutil.rmtree(tmpPath)
+ try:
+ self.assertEqual(cq.name, 'this_query')
+ self.assertTrue(cq.isActive)
+ cq.processAllAvailable()
+ output_files = []
+ for _, _, files in os.walk(out):
+ output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+ self.assertTrue(len(output_files) > 0)
+ self.assertTrue(len(os.listdir(chk)) > 0)
+ self.assertFalse(os.path.isdir(fake1)) # should not have been created
+ self.assertFalse(os.path.isdir(fake2)) # should not have been created
+ finally:
+ cq.stop()
+ shutil.rmtree(tmpPath)
def test_stream_await_termination(self):
df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.sqlCtx.streams.active:
+ cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
@@ -976,19 +986,50 @@ class SQLTests(ReusedPySparkTestCase):
chk = os.path.join(tmpPath, 'chk')
cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
checkpointLocation=chk)
- self.assertTrue(cq.isActive)
try:
- cq.awaitTermination("hello")
- self.fail("Expected a value exception")
- except ValueError:
- pass
- now = time.time()
- res = cq.awaitTermination(2600) # test should take at least 2 seconds
- duration = time.time() - now
- self.assertTrue(duration >= 2)
- self.assertFalse(res)
- cq.stop()
+ self.assertTrue(cq.isActive)
+ try:
+ cq.awaitTermination("hello")
+ self.fail("Expected a value exception")
+ except ValueError:
+ pass
+ now = time.time()
+ # test should take at least 2 seconds
+ res = cq.awaitTermination(2.6)
+ duration = time.time() - now
+ self.assertTrue(duration >= 2)
+ self.assertFalse(res)
+ finally:
+ cq.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_query_manager_await_termination(self):
+ df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.sqlCtx.streams.active:
+ cq.stop()
+ tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
+ checkpointLocation=chk)
+ try:
+ self.assertTrue(cq.isActive)
+ try:
+ self.sqlCtx.streams.awaitAnyTermination("hello")
+ self.fail("Expected a value exception")
+ except ValueError:
+ pass
+ now = time.time()
+ # test should take at least 2 seconds
+ res = self.sqlCtx.streams.awaitAnyTermination(2.6)
+ duration = time.time() - now
+ self.assertTrue(duration >= 2)
+ self.assertFalse(res)
+ finally:
+ cq.stop()
+ shutil.rmtree(tmpPath)
def test_help_command(self):
# Regression test for SPARK-5464
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7ea0e0d5c9..cb172d21f3 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -45,6 +45,12 @@ class IllegalArgumentException(CapturedException):
"""
+class ContinuousQueryException(CapturedException):
+ """
+ Exception that stopped a :class:`ContinuousQuery`.
+ """
+
+
def capture_sql_exception(f):
def deco(*a, **kw):
try:
@@ -57,6 +63,8 @@ def capture_sql_exception(f):
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
raise ParseException(s.split(': ', 1)[1], stackTrace)
+ if s.startswith('org.apache.spark.sql.ContinuousQueryException: '):
+ raise ContinuousQueryException(s.split(': ', 1)[1], stackTrace)
if s.startswith('java.lang.IllegalArgumentException: '):
raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
raise