Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--  python/pyspark/sql/dataframe.py   |  12
-rw-r--r--  python/pyspark/sql/readwriter.py  | 121
-rw-r--r--  python/pyspark/sql/streaming.py   | 124
-rw-r--r--  python/pyspark/sql/tests.py       |  93
4 files changed, 349 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 328bda6601..bbe15f5f90 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -197,6 +197,18 @@ class DataFrame(object):
"""
return self._jdf.isLocal()
+ @property
+ @since(2.0)
+ def isStreaming(self):
+ """Returns true if this :class:`Dataset` contains one or more sources that continuously
+ return data as it arrives. A :class:`Dataset` that reads data from a streaming source
+ must be executed as a :class:`ContinuousQuery` using the :func:`startStream` method in
+ :class:`DataFrameWriter`. Methods that return a single answer (e.g., :func:`count` or
+ :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming
+ source present.
+ """
+ return self._jdf.isStreaming()
+
@since(1.3)
def show(self, n=20, truncate=True):
"""Prints the first ``n`` rows to the console.
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 0cef37e57c..6c809d1139 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -22,7 +22,7 @@ if sys.version >= '3':
from py4j.java_gateway import JavaClass
-from pyspark import RDD, since
+from pyspark import RDD, since, keyword_only
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import _to_seq
from pyspark.sql.types import *
@@ -136,6 +136,32 @@ class DataFrameReader(object):
else:
return self._df(self._jreader.load())
+ @since(2.0)
+ def stream(self, path=None, format=None, schema=None, **options):
+ """Loads a data stream from a data source and returns it as a :class`DataFrame`.
+
+ :param path: optional string for file-system backed data sources.
+ :param format: optional string for format of the data source. Defaults to 'parquet'.
+ :param schema: optional :class:`StructType` for the input schema.
+ :param options: all other string options
+
+ >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
+ >>> df.isStreaming
+ True
+ """
+ if format is not None:
+ self.format(format)
+ if schema is not None:
+ self.schema(schema)
+ self.options(**options)
+ if path is not None:
+ if type(path) != str or len(path.strip()) == 0:
+ raise ValueError("If the path is provided for stream, it needs to be a " +
+ "non-empty string. List of paths are not supported.")
+ return self._df(self._jreader.stream(path))
+ else:
+ return self._df(self._jreader.stream())
+
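A short sketch of `stream()` with an explicit schema and keyword arguments, mirroring the reader options exercised in the new tests; the single `data` column matches the text sources used there:

    from pyspark.sql.types import StructType, StructField, StringType

    # Build the input schema explicitly instead of letting the source infer it.
    schema = StructType([StructField('data', StringType(), False)])
    sdf = sqlContext.read \
        .format('text') \
        .schema(schema) \
        .stream(path='python/test_support/sql/streaming')
    assert sdf.isStreaming
    assert sdf.schema.simpleString() == 'struct<data:string>'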
@since(1.4)
def json(self, path, schema=None):
"""
@@ -334,6 +360,10 @@ class DataFrameWriter(object):
self._sqlContext = df.sql_ctx
self._jwrite = df._jdf.write()
+ def _cq(self, jcq):
+ from pyspark.sql.streaming import ContinuousQuery
+ return ContinuousQuery(jcq, self._sqlContext)
+
@since(1.4)
def mode(self, saveMode):
"""Specifies the behavior when data or table already exists.
@@ -395,6 +425,44 @@ class DataFrameWriter(object):
self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
return self
+ @since(2.0)
+ def queryName(self, queryName):
+ """Specifies the name of the :class:`ContinuousQuery` that can be started with
+ :func:`startStream`. This name must be unique among all the currently active queries
+ in the associated SQLContext.
+
+ :param queryName: unique name for the query
+
+ >>> writer = sdf.write.queryName('streaming_query')
+ """
+ if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
+ raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName)
+ self._jwrite = self._jwrite.queryName(queryName)
+ return self
+
+ @keyword_only
+ @since(2.0)
+ def trigger(self, processingTime=None):
+ """Set the trigger for the stream query. If this is not set it will run the query as fast
+ as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``.
+
+ :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'.
+
+ >>> # trigger the query for execution every 5 seconds
+ >>> writer = sdf.write.trigger(processingTime='5 seconds')
+ """
+ from pyspark.sql.streaming import ProcessingTime
+ trigger = None
+ if processingTime is not None:
+ if type(processingTime) != str or len(processingTime.strip()) == 0:
+ raise ValueError('The processing time must be a non-empty string. Got: %s' %
+ processingTime)
+ trigger = ProcessingTime(processingTime)
+ if trigger is None:
+ raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
+ self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext))
+ return self
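Because `trigger` is wrapped in `@keyword_only`, the interval must be passed by name; a positional argument raises `TypeError`, which is what `test_stream_trigger_takes_keyword_args` below exercises. A small sketch, assuming the streaming DataFrame `sdf` from the doctest globals:

    writer = sdf.write.trigger(processingTime='10 seconds')   # accepted
    try:
        sdf.write.trigger('10 seconds')                       # positional argument is rejected
    except TypeError:
        pass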
+
@since(1.4)
def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
@@ -426,6 +494,55 @@ class DataFrameWriter(object):
else:
self._jwrite.save(path)
+ @ignore_unicode_prefix
+ @since(2.0)
+ def startStream(self, path=None, format=None, partitionBy=None, queryName=None, **options):
+ """Streams the contents of the :class:`DataFrame` to a data source.
+
+ The data source is specified by the ``format`` and a set of ``options``.
+ If ``format`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
+
+ :param path: the path in a Hadoop supported file system
+ :param format: the format used for the stream sink, e.g. ``parquet`` or ``memory``
+ :param partitionBy: names of partitioning columns
+ :param queryName: unique name for the query
+ :param options: all other string options. You may want to provide a `checkpointLocation`
+ for most streams; however, it is not required for a `memory` stream.
+
+ >>> cq = sdf.write.format('memory').queryName('this_query').startStream()
+ >>> cq.isActive
+ True
+ >>> cq.name
+ u'this_query'
+ >>> cq.stop()
+ >>> cq.isActive
+ False
+ >>> cq = sdf.write.trigger(processingTime='5 seconds').startStream(
+ ... queryName='that_query', format='memory')
+ >>> cq.name
+ u'that_query'
+ >>> cq.isActive
+ True
+ >>> cq.stop()
+ """
+ self.options(**options)
+ if partitionBy is not None:
+ self.partitionBy(partitionBy)
+ if format is not None:
+ self.format(format)
+ if queryName is not None:
+ self.queryName(queryName)
+ if path is None:
+ return self._cq(self._jwrite.startStream())
+ else:
+ return self._cq(self._jwrite.startStream(path))
+
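An end-to-end sketch combining the new writer methods, assuming `sdf` is a streaming DataFrame and using a scratch directory; a `checkpointLocation` is set because the sink here is file based (`parquet`), unlike the `memory` sink in the doctest above:

    import os
    import tempfile

    tmp = tempfile.mkdtemp()
    cq = sdf.write \
        .format('parquet') \
        .option('checkpointLocation', os.path.join(tmp, 'chk')) \
        .queryName('example_query') \
        .trigger(processingTime='5 seconds') \
        .startStream(path=os.path.join(tmp, 'out'))
    cq.processAllAvailable()   # block until currently available input has been committed
    cq.stop()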
@since(1.4)
def insertInto(self, tableName, overwrite=False):
"""Inserts the content of the :class:`DataFrame` to the specified table.
@@ -625,6 +742,8 @@ def _test():
globs['sqlContext'] = SQLContext(sc)
globs['hiveContext'] = HiveContext(sc)
globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+ globs['sdf'] =\
+ globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
new file mode 100644
index 0000000000..549561669f
--- /dev/null
+++ b/python/pyspark/sql/streaming.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark import since
+
+__all__ = ["ContinuousQuery"]
+
+
+class ContinuousQuery(object):
+ """
+ A handle to a query that is executing continuously in the background as new data arrives.
+ All these methods are thread-safe.
+
+ .. note:: Experimental
+
+ .. versionadded:: 2.0
+ """
+
+ def __init__(self, jcq, sqlContext):
+ self._jcq = jcq
+ self._sqlContext = sqlContext
+
+ @property
+ @since(2.0)
+ def name(self):
+ """The name of the continuous query.
+ """
+ return self._jcq.name()
+
+ @property
+ @since(2.0)
+ def isActive(self):
+ """Whether this continuous query is currently active or not.
+ """
+ return self._jcq.isActive()
+
+ @since(2.0)
+ def awaitTermination(self, timeoutMs=None):
+ """Waits for the termination of `this` query, either by :func:`query.stop()` or by an
+ exception. If the query has terminated with an exception, then the exception will be thrown.
+ If `timeoutMs` is set, it returns whether the query has terminated or not within the
+ `timeoutMs` milliseconds.
+
+ If the query has terminated, then all subsequent calls to this method will either return
+ immediately (if the query was terminated by :func:`stop()`), or throw the exception
+ immediately (if the query has terminated with exception).
+
+ throws ContinuousQueryException if this query has terminated with an exception
+ """
+ if timeoutMs is not None:
+ if type(timeoutMs) != int or timeoutMs < 0:
+ raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs)
+ return self._jcq.awaitTermination(timeoutMs)
+ else:
+ return self._jcq.awaitTermination()
+
+ @since(2.0)
+ def processAllAvailable(self):
+ """Blocks until all available data in the source has been processed an committed to the
+ sink. This method is intended for testing. Note that in the case of continually arriving
+ data, this method may block forever. Additionally, this method is only guaranteed to block
+ until data that has been synchronously appended data to a stream source prior to invocation.
+ (i.e. `getOffset` must immediately reflect the addition).
+ """
+ return self._jcq.processAllAvailable()
+
+ @since(2.0)
+ def stop(self):
+ """Stop this continuous query.
+ """
+ self._jcq.stop()
+
+
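A sketch of the two `awaitTermination` forms, assuming `cq` is an active query returned by `startStream`; the timeout variant returns whether the query terminated within the given number of milliseconds, so it can be polled:

    # Poll with a timeout instead of blocking indefinitely.
    while cq.isActive:
        terminated = cq.awaitTermination(1000)   # wait at most one second
        if terminated:
            break   # the query stopped (a failed query would have raised its exception instead)
        # the query is still running; other work could be done between polls

    # cq.awaitTermination() with no argument blocks until the query terminates.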
+class Trigger(object):
+ """Used to indicate how often results should be produced by a :class:`ContinuousQuery`.
+
+ .. note:: Experimental
+
+ .. versionadded:: 2.0
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def _to_java_trigger(self, sqlContext):
+ """Internal method to construct the trigger on the jvm.
+ """
+ pass
+
+
+class ProcessingTime(Trigger):
+ """A trigger that runs a query periodically based on the processing time. If `interval` is 0,
+ the query will run as fast as possible.
+
+ The interval should be given as a string, e.g. '2 seconds', '5 minutes', ...
+
+ .. note:: Experimental
+
+ .. versionadded:: 2.0
+ """
+
+ def __init__(self, interval):
+ if interval is None or type(interval) != str or len(interval.strip()) == 0:
+ raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.")
+ self.interval = interval
+
+ def _to_java_trigger(self, sqlContext):
+ return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval)
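`ProcessingTime` is normally constructed indirectly through `DataFrameWriter.trigger`, but it can also be instantiated directly; a small sketch of the validation above:

    from pyspark.sql.streaming import ProcessingTime

    t = ProcessingTime('2 seconds')
    assert t.interval == '2 seconds'

    try:
        ProcessingTime('   ')   # a blank interval string is rejected
    except ValueError:
        pass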
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d4c221d712..1e864b4cd1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -879,6 +879,99 @@ class SQLTests(ReusedPySparkTestCase):
shutil.rmtree(tmpPath)
+ def test_stream_trigger_takes_keyword_args(self):
+ df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ try:
+ df.write.trigger('5 seconds')
+ self.fail("Should have thrown an exception")
+ except TypeError:
+ # should throw error
+ pass
+
+ def test_stream_read_options(self):
+ schema = StructType([StructField("data", StringType(), False)])
+ df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\
+ .schema(schema).stream()
+ self.assertTrue(df.isStreaming)
+ self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+ def test_stream_read_options_overwrite(self):
+ bad_schema = StructType([StructField("test", IntegerType(), False)])
+ schema = StructType([StructField("data", StringType(), False)])
+ df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \
+ .schema(bad_schema).stream(path='python/test_support/sql/streaming',
+ schema=schema, format='text')
+ self.assertTrue(df.isStreaming)
+ self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+ def test_stream_save_options(self):
+ df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ cq = df.write.option('checkpointLocation', chk).queryName('this_query')\
+ .format('parquet').option('path', out).startStream()
+ self.assertEqual(cq.name, 'this_query')
+ self.assertTrue(cq.isActive)
+ cq.processAllAvailable()
+ output_files = []
+ for _, _, files in os.walk(out):
+ output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+ self.assertTrue(len(output_files) > 0)
+ self.assertTrue(len(os.listdir(chk)) > 0)
+ cq.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_stream_save_options_overwrite(self):
+ df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ fake1 = os.path.join(tmpPath, 'fake1')
+ fake2 = os.path.join(tmpPath, 'fake2')
+ cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \
+ .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query',
+ checkpointLocation=chk)
+ self.assertEqual(cq.name, 'this_query')
+ self.assertTrue(cq.isActive)
+ cq.processAllAvailable()
+ output_files = []
+ for _, _, files in os.walk(out):
+ output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')])
+ self.assertTrue(len(output_files) > 0)
+ self.assertTrue(len(os.listdir(chk)) > 0)
+ self.assertFalse(os.path.isdir(fake1)) # should not have been created
+ self.assertFalse(os.path.isdir(fake2)) # should not have been created
+ cq.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_stream_await_termination(self):
+ df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ cq = df.write.startStream(path=out, format='parquet', queryName='this_query',
+ checkpointLocation=chk)
+ self.assertTrue(cq.isActive)
+ try:
+ cq.awaitTermination("hello")
+ self.fail("Expected a value exception")
+ except ValueError:
+ pass
+ now = time.time()
+ res = cq.awaitTermination(2600) # test should take at least 2 seconds
+ duration = time.time() - now
+ self.assertTrue(duration >= 2)
+ self.assertFalse(res)
+ cq.stop()
+ shutil.rmtree(tmpPath)
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])