Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--  python/pyspark/sql/dataframe.py   |  12
-rw-r--r--  python/pyspark/sql/readwriter.py  | 121
-rw-r--r--  python/pyspark/sql/streaming.py   | 124
-rw-r--r--  python/pyspark/sql/tests.py       |  93
4 files changed, 349 insertions, 1 deletion
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 328bda6601..bbe15f5f90 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -197,6 +197,18 @@ class DataFrame(object):
         """
         return self._jdf.isLocal()
 
+    @property
+    @since(2.0)
+    def isStreaming(self):
+        """Returns true if this :class:`Dataset` contains one or more sources that continuously
+        return data as it arrives. A :class:`Dataset` that reads data from a streaming source
+        must be executed as a :class:`ContinuousQuery` using the :func:`startStream` method in
+        :class:`DataFrameWriter`. Methods that return a single answer (e.g., :func:`count` or
+        :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming
+        source present.
+        """
+        return self._jdf.isStreaming()
+
     @since(1.3)
     def show(self, n=20, truncate=True):
         """Prints the first ``n`` rows to the console.
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 0cef37e57c..6c809d1139 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -22,7 +22,7 @@ if sys.version >= '3':
 
 from py4j.java_gateway import JavaClass
 
-from pyspark import RDD, since
+from pyspark import RDD, since, keyword_only
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.column import _to_seq
 from pyspark.sql.types import *
@@ -136,6 +136,32 @@ class DataFrameReader(object):
         else:
             return self._df(self._jreader.load())
 
+    @since(2.0)
+    def stream(self, path=None, format=None, schema=None, **options):
+        """Loads a data stream from a data source and returns it as a :class:`DataFrame`.
+
+        :param path: optional string for file-system backed data sources.
+        :param format: optional string for format of the data source. Defaults to 'parquet'.
+        :param schema: optional :class:`StructType` for the input schema.
+        :param options: all other string options
+
+        >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
+        >>> df.isStreaming
+        True
+        """
+        if format is not None:
+            self.format(format)
+        if schema is not None:
+            self.schema(schema)
+        self.options(**options)
+        if path is not None:
+            if type(path) != str or len(path.strip()) == 0:
+                raise ValueError("If the path is provided for stream, it needs to be a " +
+                                 "non-empty string. List of paths are not supported.")
+            return self._df(self._jreader.stream(path))
+        else:
+            return self._df(self._jreader.stream())
+
     @since(1.4)
     def json(self, path, schema=None):
         """
@@ -334,6 +360,10 @@ class DataFrameWriter(object):
         self._sqlContext = df.sql_ctx
         self._jwrite = df._jdf.write()
 
+    def _cq(self, jcq):
+        from pyspark.sql.streaming import ContinuousQuery
+        return ContinuousQuery(jcq, self._sqlContext)
+
     @since(1.4)
     def mode(self, saveMode):
         """Specifies the behavior when data or table already exists.
@@ -395,6 +425,44 @@ class DataFrameWriter(object):
         self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
         return self
 
+    @since(2.0)
+    def queryName(self, queryName):
+        """Specifies the name of the :class:`ContinuousQuery` that can be started with
+        :func:`startStream`. This name must be unique among all the currently active queries
+        in the associated SQLContext.
+
+        :param queryName: unique name for the query
+
+        >>> writer = sdf.write.queryName('streaming_query')
+        """
+        if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
+            raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName)
+        self._jwrite = self._jwrite.queryName(queryName)
+        return self
+
+    @keyword_only
+    @since(2.0)
+    def trigger(self, processingTime=None):
+        """Set the trigger for the stream query. If this is not set it will run the query as fast
+        as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``.
+
+        :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'.
+
+        >>> # trigger the query for execution every 5 seconds
+        >>> writer = sdf.write.trigger(processingTime='5 seconds')
+        """
+        from pyspark.sql.streaming import ProcessingTime
+        trigger = None
+        if processingTime is not None:
+            if type(processingTime) != str or len(processingTime.strip()) == 0:
+                raise ValueError('The processing time must be a non empty string. Got: %s' %
+                                 processingTime)
+            trigger = ProcessingTime(processingTime)
+        if trigger is None:
+            raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
+        self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext))
+        return self
+
     @since(1.4)
     def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
@@ -426,6 +494,55 @@ class DataFrameWriter(object):
         else:
             self._jwrite.save(path)
 
+    @ignore_unicode_prefix
+    @since(2.0)
+    def startStream(self, path=None, format=None, partitionBy=None, queryName=None, **options):
+        """Streams the contents of the :class:`DataFrame` to a data source.
+
+        The data source is specified by the ``format`` and a set of ``options``.
+        If ``format`` is not specified, the default data source configured by
+        ``spark.sql.sources.default`` will be used.
+
+        :param path: the path in a Hadoop supported file system
+        :param format: the format used to save
+
+            * ``append``: Append contents of this :class:`DataFrame` to existing data.
+            * ``overwrite``: Overwrite existing data.
+            * ``ignore``: Silently ignore this operation if data already exists.
+            * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
+        :param queryName: unique name for the query
+        :param options: All other string options. You may want to provide a `checkpointLocation`
+            for most streams, however it is not required for a `memory` stream.
+
+        >>> cq = sdf.write.format('memory').queryName('this_query').startStream()
+        >>> cq.isActive
+        True
+        >>> cq.name
+        u'this_query'
+        >>> cq.stop()
+        >>> cq.isActive
+        False
+        >>> cq = sdf.write.trigger(processingTime='5 seconds').startStream(
+        ...     queryName='that_query', format='memory')
+        >>> cq.name
+        u'that_query'
+        >>> cq.isActive
+        True
+        >>> cq.stop()
+        """
+        self.options(**options)
+        if partitionBy is not None:
+            self.partitionBy(partitionBy)
+        if format is not None:
+            self.format(format)
+        if queryName is not None:
+            self.queryName(queryName)
+        if path is None:
+            return self._cq(self._jwrite.startStream())
+        else:
+            return self._cq(self._jwrite.startStream(path))
+
     @since(1.4)
     def insertInto(self, tableName, overwrite=False):
         """Inserts the content of the :class:`DataFrame` to the specified table.
@@ -625,6 +742,8 @@ def _test():
     globs['sqlContext'] = SQLContext(sc)
     globs['hiveContext'] = HiveContext(sc)
     globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+    globs['sdf'] =\
+        globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.readwriter, globs=globs,
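Usage note: the snippet below strings the new reader/writer methods together the way the doctests above do. It is a minimal sketch rather than part of the patch; it assumes a local SparkContext/SQLContext, the doctest input directory python/test_support/sql/streaming, and hypothetical output and checkpoint paths ('/tmp/out', '/tmp/chk').

from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext(appName='streaming_sketch')
sqlContext = SQLContext(sc)

# Open a streaming DataFrame over a directory of text files (DataFrameReader.stream).
sdf = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
assert sdf.isStreaming

# Start a ContinuousQuery that writes parquet output (DataFrameWriter.startStream).
cq = (sdf.write
      .format('parquet')
      .option('checkpointLocation', '/tmp/chk')   # hypothetical checkpoint directory
      .trigger(processingTime='5 seconds')
      .queryName('text_to_parquet')
      .startStream(path='/tmp/out'))              # hypothetical output directory

print(cq.name, cq.isActive)
cq.processAllAvailable()  # block until data already in the source has been committed
cq.stop()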
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
new file mode 100644
index 0000000000..549561669f
--- /dev/null
+++ b/python/pyspark/sql/streaming.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark import since
+
+__all__ = ["ContinuousQuery"]
+
+
+class ContinuousQuery(object):
+    """
+    A handle to a query that is executing continuously in the background as new data arrives.
+    All these methods are thread-safe.
+
+    .. note:: Experimental
+
+    .. versionadded:: 2.0
+    """
+
+    def __init__(self, jcq, sqlContext):
+        self._jcq = jcq
+        self._sqlContext = sqlContext
+
+    @property
+    @since(2.0)
+    def name(self):
+        """The name of the continuous query.
+        """
+        return self._jcq.name()
+
+    @property
+    @since(2.0)
+    def isActive(self):
+        """Whether this continuous query is currently active or not.
+        """
+        return self._jcq.isActive()
+
+    @since(2.0)
+    def awaitTermination(self, timeoutMs=None):
+        """Waits for the termination of `this` query, either by :func:`query.stop()` or by an
+        exception. If the query has terminated with an exception, then the exception will be
+        thrown. If `timeoutMs` is set, it returns whether the query has terminated or not within
+        the `timeoutMs` milliseconds.
+
+        If the query has terminated, then all subsequent calls to this method will either return
+        immediately (if the query was terminated by :func:`stop()`), or throw the exception
+        immediately (if the query has terminated with exception).
+
+        throws ContinuousQueryException, if `this` query has terminated with an exception
+        """
+        if timeoutMs is not None:
+            if type(timeoutMs) != int or timeoutMs < 0:
+                raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs)
+            return self._jcq.awaitTermination(timeoutMs)
+        else:
+            return self._jcq.awaitTermination()
+
+    @since(2.0)
+    def processAllAvailable(self):
+        """Blocks until all available data in the source has been processed and committed to the
+        sink. This method is intended for testing. Note that in the case of continually arriving
+        data, this method may block forever. Additionally, this method is only guaranteed to block
+        until data that has been synchronously appended to a stream source prior to invocation
+        (i.e. `getOffset` must immediately reflect the addition).
+        """
+        return self._jcq.processAllAvailable()
+
+    @since(2.0)
+    def stop(self):
+        """Stop this continuous query.
+        """
+        self._jcq.stop()
+
+
+class Trigger(object):
+    """Used to indicate how often results should be produced by a :class:`ContinuousQuery`.
+
+    .. note:: Experimental
+
+    .. versionadded:: 2.0
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def _to_java_trigger(self, sqlContext):
+        """Internal method to construct the trigger on the jvm.
+        """
+        pass
+
+
+class ProcessingTime(Trigger):
+    """A trigger that runs a query periodically based on the processing time. If `interval` is 0,
+    the query will run as fast as possible.
+
+    The interval should be given as a string, e.g. '2 seconds', '5 minutes', ...
+
+    .. note:: Experimental
+
+    .. versionadded:: 2.0
+    """
+
+    def __init__(self, interval):
+        if interval is None or type(interval) != str or len(interval.strip()) == 0:
+            raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.")
+        self.interval = interval
+
+    def _to_java_trigger(self, sqlContext):
+        return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval)
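Usage note: a short sketch of how ContinuousQuery and ProcessingTime above are driven from the writer API, again assuming the streaming DataFrame `sdf` from the readwriter doctests and the same hypothetical paths.

# trigger(processingTime=...) wraps pyspark.sql.streaming.ProcessingTime under the hood.
cq = (sdf.write
      .trigger(processingTime='2 seconds')
      .startStream(path='/tmp/out', format='parquet', queryName='await_demo',
                   checkpointLocation='/tmp/chk'))

# With a timeout, awaitTermination returns whether the query terminated within
# `timeoutMs` milliseconds; without one it blocks until stop() or a failure.
terminated = cq.awaitTermination(2000)
print(terminated, cq.isActive)
cq.stop()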
+ """ + return self._jcq.processAllAvailable() + + @since(2.0) + def stop(self): + """Stop this continuous query. + """ + self._jcq.stop() + + +class Trigger(object): + """Used to indicate how often results should be produced by a :class:`ContinuousQuery`. + + .. note:: Experimental + + .. versionadded:: 2.0 + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _to_java_trigger(self, sqlContext): + """Internal method to construct the trigger on the jvm. + """ + pass + + +class ProcessingTime(Trigger): + """A trigger that runs a query periodically based on the processing time. If `interval` is 0, + the query will run as fast as possible. + + The interval should be given as a string, e.g. '2 seconds', '5 minutes', ... + + .. note:: Experimental + + .. versionadded:: 2.0 + """ + + def __init__(self, interval): + if interval is None or type(interval) != str or len(interval.strip()) == 0: + raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.") + self.interval = interval + + def _to_java_trigger(self, sqlContext): + return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d4c221d712..1e864b4cd1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -879,6 +879,99 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) + def test_stream_trigger_takes_keyword_args(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + try: + df.write.trigger('5 seconds') + self.fail("Should have thrown an exception") + except TypeError: + # should throw error + pass + + def test_stream_read_options(self): + schema = StructType([StructField("data", StringType(), False)]) + df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\ + .schema(schema).stream() + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct<data:string>") + + def test_stream_read_options_overwrite(self): + bad_schema = StructType([StructField("test", IntegerType(), False)]) + schema = StructType([StructField("data", StringType(), False)]) + df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \ + .schema(bad_schema).stream(path='python/test_support/sql/streaming', + schema=schema, format='text') + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct<data:string>") + + def test_stream_save_options(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + cq = df.write.option('checkpointLocation', chk).queryName('this_query')\ + .format('parquet').option('path', out).startStream() + self.assertEqual(cq.name, 'this_query') + self.assertTrue(cq.isActive) + cq.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + cq.stop() + shutil.rmtree(tmpPath) + + def test_stream_save_options_overwrite(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 
'chk') + fake1 = os.path.join(tmpPath, 'fake1') + fake2 = os.path.join(tmpPath, 'fake2') + cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ + .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query', + checkpointLocation=chk) + self.assertEqual(cq.name, 'this_query') + self.assertTrue(cq.isActive) + cq.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + self.assertFalse(os.path.isdir(fake1)) # should not have been created + self.assertFalse(os.path.isdir(fake2)) # should not have been created + cq.stop() + shutil.rmtree(tmpPath) + + def test_stream_await_termination(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + cq = df.write.startStream(path=out, format='parquet', queryName='this_query', + checkpointLocation=chk) + self.assertTrue(cq.isActive) + try: + cq.awaitTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + res = cq.awaitTermination(2600) # test should take at least 2 seconds + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + cq.stop() + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) |
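Usage note: the two *_overwrite tests above rely on keyword arguments passed to stream() and startStream() taking precedence over options set earlier on the builder. A condensed sketch of the read side, assuming the same sqlContext; the 'fake' csv settings are only illustrative.

from pyspark.sql.types import StructType, StructField, StringType

schema = StructType([StructField('data', StringType(), False)])
df = (sqlContext.read.format('csv')
      .option('path', 'python/test_support/sql/fake')
      .stream(path='python/test_support/sql/streaming', schema=schema, format='text'))
# The format, path and schema given to stream() win, so this reads the text source.
assert df.isStreaming and df.schema.simpleString() == 'struct<data:string>'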