From 855ed44ed31210d2001d7ce67c8fa99f8416edd3 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Mon, 4 Apr 2016 10:54:06 -0700
Subject: [SPARK-14176][SQL] Add DataFrameWriter.trigger to set the stream batch period

## What changes were proposed in this pull request?

Add a processing time trigger to control the batch processing speed.

## How was this patch tested?

Unit tests

Author: Shixiong Zhu

Closes #11976 from zsxwing/trigger.
---
 .../apache/spark/sql/ContinuousQueryManager.scala  |  11 +-
 .../org/apache/spark/sql/DataFrameWriter.scala     |  34 +++++-
 .../main/scala/org/apache/spark/sql/Trigger.scala  | 133 +++++++++++++++++++++
 .../sql/execution/streaming/StreamExecution.scala  |  24 ++--
 .../sql/execution/streaming/TriggerExecutor.scala  |  72 +++++++++++
 .../org/apache/spark/sql/ProcessingTimeSuite.scala |  40 +++++++
 .../scala/org/apache/spark/sql/StreamTest.scala    |   6 +-
 .../streaming/ProcessingTimeExecutorSuite.scala    |  78 ++++++++++++
 .../sql/streaming/DataFrameReaderWriterSuite.scala |  28 +++++
 9 files changed, 413 insertions(+), 13 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 465feeb604..2306df09b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -171,13 +171,20 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
       name: String,
       checkpointLocation: String,
       df: DataFrame,
-      sink: Sink): ContinuousQuery = {
+      sink: Sink,
+      trigger: Trigger = ProcessingTime(0)): ContinuousQuery = {
     activeQueriesLock.synchronized {
       if (activeQueries.contains(name)) {
         throw new IllegalArgumentException(
           s"Cannot start query with name $name as a query with that name is already active")
       }
-      val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink)
+      val query = new StreamExecution(
+        sqlContext,
+        name,
+        checkpointLocation,
+        df.logicalPlan,
+        sink,
+        trigger)
       query.start()
       activeQueries.put(name, query)
       query
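The new `trigger` parameter defaults to `ProcessingTime(0)`, so existing internal callers (such as `StreamTest` further down) keep the run-as-fast-as-possible behavior. A minimal sketch of the two now-equivalent calls, assuming an in-scope `sqlContext`, a streaming `DataFrame` `df`, a `Sink` `sink`, and a checkpoint path `checkpointDir` (hypothetical names):

```scala
import org.apache.spark.sql.ProcessingTime

// Callers that don't care about pacing keep the old four-argument form...
val q1 = sqlContext.streams.startQuery("q1", checkpointDir, df, sink)

// ...which is equivalent to passing the zero-interval trigger explicitly.
val q2 = sqlContext.streams.startQuery(
  "q2", checkpointDir, df, sink, trigger = ProcessingTime(0))
```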
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index c07bd0e7b7..3332a997cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -77,6 +77,35 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     this
   }
 
+  /**
+   * :: Experimental ::
+   * Set the trigger for the stream query. The default value is `ProcessingTime(0)`, which runs
+   * the query as fast as possible.
+   *
+   * Scala Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime("10 seconds"))
+   *
+   *   import scala.concurrent.duration._
+   *   df.write.trigger(ProcessingTime(10.seconds))
+   * }}}
+   *
+   * Java Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime.create("10 seconds"))
+   *
+   *   import java.util.concurrent.TimeUnit
+   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def trigger(trigger: Trigger): DataFrameWriter = {
+    this.trigger = trigger
+    this
+  }
+
   /**
    * Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
    *
@@ -261,7 +290,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       queryName,
       checkpointLocation,
       df,
-      dataSource.createSink())
+      dataSource.createSink(),
+      trigger)
   }
 
   /**
@@ -552,6 +582,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
 
   private var mode: SaveMode = SaveMode.ErrorIfExists
 
+  private var trigger: Trigger = ProcessingTime(0L)
+
   private var extraOptions = new scala.collection.mutable.HashMap[String, String]
 
   private var partitioningColumns: Option[Seq[String]] = None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
new file mode 100644
index 0000000000..c4e54b3f90
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration.Duration
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * :: Experimental ::
+ * Used to indicate how often results should be produced by a [[ContinuousQuery]].
+ */
+@Experimental
+sealed trait Trigger {}
+
+/**
+ * :: Experimental ::
+ * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0,
+ * the query will run as fast as possible.
+ *
+ * Scala Example:
+ * {{{
+ *   df.write.trigger(ProcessingTime("10 seconds"))
+ *
+ *   import scala.concurrent.duration._
+ *   df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ *
+ * Java Example:
+ * {{{
+ *   df.write.trigger(ProcessingTime.create("10 seconds"))
+ *
+ *   import java.util.concurrent.TimeUnit
+ *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ */
+@Experimental
+case class ProcessingTime(intervalMs: Long) extends Trigger {
+  require(intervalMs >= 0, "the interval of trigger should not be negative")
+}
+
+/**
+ * :: Experimental ::
+ * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s.
+ */
+@Experimental
+object ProcessingTime {
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime("10 seconds"))
+   * }}}
+   */
+  def apply(interval: String): ProcessingTime = {
+    if (StringUtils.isBlank(interval)) {
+      throw new IllegalArgumentException(
+        "interval cannot be null or blank.")
+    }
+    val cal = if (interval.startsWith("interval")) {
+      CalendarInterval.fromString(interval)
+    } else {
+      CalendarInterval.fromString("interval " + interval)
+    }
+    if (cal == null) {
+      throw new IllegalArgumentException(s"Invalid interval: $interval")
+    }
+    if (cal.months > 0) {
+      throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
+    }
+    new ProcessingTime(cal.microseconds / 1000)
+  }
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   import scala.concurrent.duration._
+   *   df.write.trigger(ProcessingTime(10.seconds))
+   * }}}
+   */
+  def apply(interval: Duration): ProcessingTime = {
+    new ProcessingTime(interval.toMillis)
+  }
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime.create("10 seconds"))
+   * }}}
+   */
+  def create(interval: String): ProcessingTime = {
+    apply(interval)
+  }
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   import java.util.concurrent.TimeUnit
+   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+   * }}}
+   */
+  def create(interval: Long, unit: TimeUnit): ProcessingTime = {
+    new ProcessingTime(unit.toMillis(interval))
+  }
+}
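`ProcessingTime.apply(String)` accepts an interval string with or without the leading `interval` keyword, and rejects blank input and calendar-dependent units, since months and years have no fixed millisecond length. A quick sketch of the accepted and rejected forms, mirroring `ProcessingTimeSuite` below:

```scala
import org.apache.spark.sql.ProcessingTime

ProcessingTime("10 seconds").intervalMs          // 10000
ProcessingTime("interval 10 seconds").intervalMs // 10000: the "interval" prefix is optional
ProcessingTime("1 minute").intervalMs            // 60000

// Rejected with IllegalArgumentException:
// ProcessingTime("")        -- blank input
// ProcessingTime("invalid") -- not a parsable interval
// ProcessingTime("1 month") -- no fixed length in milliseconds
```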
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 511e30c70c..64f80699ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -47,16 +47,14 @@ class StreamExecution(
     override val name: String,
     val checkpointRoot: String,
     private[sql] val logicalPlan: LogicalPlan,
-    val sink: Sink) extends ContinuousQuery with Logging {
+    val sink: Sink,
+    val trigger: Trigger) extends ContinuousQuery with Logging {
 
   /** A monitor used to wait/notify when batches complete. */
   private val awaitBatchLock = new Object
   private val startLatch = new CountDownLatch(1)
   private val terminationLatch = new CountDownLatch(1)
 
-  /** Minimum amount of time in between the start of each batch. */
-  private val minBatchTime = 10
-
   /**
    * Tracks how much data we have processed and committed to the sink or state store from each
    * input source.
@@ -79,6 +77,10 @@ class StreamExecution(
   /** A list of unique sources in the query plan. */
   private val uniqueSources = sources.distinct
 
+  private val triggerExecutor = trigger match {
+    case t: ProcessingTime => ProcessingTimeExecutor(t)
+  }
+
   /** Defines the internal state of execution */
   @volatile
   private var state: State = INITIALIZED
@@ -154,11 +156,15 @@ class StreamExecution(
       SQLContext.setActive(sqlContext)
       populateStartOffsets()
       logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-      while (isActive) {
-        if (dataAvailable) runBatch()
-        commitAndConstructNextBatch()
-        Thread.sleep(minBatchTime) // TODO: Could be tighter
-      }
+      triggerExecutor.execute(() => {
+        if (isActive) {
+          if (dataAvailable) runBatch()
+          commitAndConstructNextBatch()
+          true
+        } else {
+          false
+        }
+      })
     } catch {
       case _: InterruptedException if state == TERMINATED => // interrupted by stop()
       case NonFatal(e) =>
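The run loop above no longer sleeps a fixed 10 ms between batches; pacing is delegated to a `TriggerExecutor` (added next), whose batch runner returns `true` to continue and `false` to stop. A minimal sketch of that contract, following the `testBatchTermination` helper in the suite further down:

```scala
import org.apache.spark.sql.ProcessingTime
import org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor

// The runner returns true to keep going; returning false ends execute().
var remaining = 3
ProcessingTimeExecutor(ProcessingTime(10)).execute(() => {
  remaining -= 1
  remaining > 0 // stops after the third batch
})
```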
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
new file mode 100644
index 0000000000..a1132d5106
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.{Clock, SystemClock}
+
+trait TriggerExecutor {
+
+  /**
+   * Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the execution.
+   */
+  def execute(batchRunner: () => Boolean): Unit
+}
+
+/**
+ * A trigger executor that runs a batch every `intervalMs` milliseconds.
+ */
+case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
+  extends TriggerExecutor with Logging {
+
+  private val intervalMs = processingTime.intervalMs
+
+  override def execute(batchRunner: () => Boolean): Unit = {
+    while (true) {
+      val batchStartTimeMs = clock.getTimeMillis()
+      val terminated = !batchRunner()
+      if (intervalMs > 0) {
+        val batchEndTimeMs = clock.getTimeMillis()
+        val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
+        if (batchElapsedTimeMs > intervalMs) {
+          notifyBatchFallingBehind(batchElapsedTimeMs)
+        }
+        if (terminated) {
+          return
+        }
+        clock.waitTillTime(nextBatchTime(batchEndTimeMs))
+      } else {
+        if (terminated) {
+          return
+        }
+      }
+    }
+  }
+
+  /** Called when a batch falls behind. Exposed for tests only. */
+  def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+    logWarning("Current batch is falling behind. The trigger interval is " +
+      s"${intervalMs} milliseconds, but the batch took ${realElapsedTimeMs} milliseconds")
+  }
+
+  /** Return the next batch start time: the smallest multiple of `intervalMs` that is >= `now`. */
+  def nextBatchTime(now: Long): Long = {
+    (now - 1) / intervalMs * intervalMs + intervalMs
+  }
+}
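`nextBatchTime` rounds a batch's end time up to the next interval boundary; writing `(now - 1) / intervalMs` keeps a timestamp that already sits exactly on a boundary from being pushed a full interval ahead. Worked with a 100 ms interval, matching the asserts in `ProcessingTimeExecutorSuite` below:

```scala
import org.apache.spark.sql.ProcessingTime
import org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor

val exec = ProcessingTimeExecutor(ProcessingTime(100))
exec.nextBatchTime(1)   // 100: ((1 - 1) / 100) * 100 + 100
exec.nextBatchTime(99)  // 100
exec.nextBatchTime(100) // 100: an exact boundary maps to itself, not 200
exec.nextBatchTime(101) // 200
```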
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
new file mode 100644
index 0000000000..0d18a645f6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkFunSuite
+
+class ProcessingTimeSuite extends SparkFunSuite {
+
+  test("create") {
+    assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
+    assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
+    assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
+    assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)
+
+    intercept[IllegalArgumentException] { ProcessingTime(null: String) }
+    intercept[IllegalArgumentException] { ProcessingTime("") }
+    intercept[IllegalArgumentException] { ProcessingTime("invalid") }
+    intercept[IllegalArgumentException] { ProcessingTime("1 month") }
+    intercept[IllegalArgumentException] { ProcessingTime("1 year") }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 550c3c6f9c..3444e56e9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -288,7 +288,11 @@ trait StreamTest extends QueryTest with Timeouts {
         currentStream =
           sqlContext
             .streams
-            .startQuery(StreamExecution.nextName, metadataRoot, stream, sink)
+            .startQuery(
+              StreamExecution.nextName,
+              metadataRoot,
+              stream,
+              sink)
             .asInstanceOf[StreamExecution]
         currentStream.microBatchThread.setUncaughtExceptionHandler(
           new UncaughtExceptionHandler {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
new file mode 100644
index 0000000000..dd5f92248b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.ManualClock
+
+class ProcessingTimeExecutorSuite extends SparkFunSuite {
+
+  test("nextBatchTime") {
+    val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100))
+    assert(processingTimeExecutor.nextBatchTime(1) === 100)
+    assert(processingTimeExecutor.nextBatchTime(99) === 100)
+    assert(processingTimeExecutor.nextBatchTime(100) === 100)
+    assert(processingTimeExecutor.nextBatchTime(101) === 200)
+    assert(processingTimeExecutor.nextBatchTime(150) === 200)
+  }
+
+  private def testBatchTermination(intervalMs: Long): Unit = {
+    var batchCounts = 0
+    val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs))
+    processingTimeExecutor.execute(() => {
+      batchCounts += 1
+      // If the batch termination works well, batchCounts should be 3 after `execute`
+      batchCounts < 3
+    })
+    assert(batchCounts === 3)
+  }
+
+  test("batch termination") {
+    testBatchTermination(0)
+    testBatchTermination(10)
+  }
+
+  test("notifyBatchFallingBehind") {
+    val clock = new ManualClock()
+    @volatile var batchFallingBehindCalled = false
+    val latch = new CountDownLatch(1)
+    val t = new Thread() {
+      override def run(): Unit = {
+        val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) {
+          override def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+            batchFallingBehindCalled = true
+          }
+        }
+        processingTimeExecutor.execute(() => {
+          latch.countDown()
+          clock.waitTillTime(200)
+          false
+        })
+      }
+    }
+    t.start()
+    // Wait until the batch is running so that we don't call `advance` too early
+    assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds")
+    clock.advance(200)
+    t.join()
+    assert(batchFallingBehindCalled === true)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 102473d7d0..28c558208f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.streaming.test
 
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql._
@@ -275,4 +279,28 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
     assert(activeStreamNames.contains("name"))
     sqlContext.streams.active.foreach(_.stop())
   }
+
+  test("trigger") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream("/test")
+
+    var q = df.write
+      .format("org.apache.spark.sql.streaming.test")
+      .option("checkpointLocation", newMetadataDir)
+      .trigger(ProcessingTime(10.seconds))
+      .startStream()
+    q.stop()
+
+    assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000))
+
+    q = df.write
+      .format("org.apache.spark.sql.streaming.test")
+      .option("checkpointLocation", newMetadataDir)
+      .trigger(ProcessingTime.create(100, TimeUnit.SECONDS))
+      .startStream()
+    q.stop()
+
+    assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
+  }
 }
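Putting the pieces together, a minimal end-to-end sketch of the new API, mirroring the test above; it assumes an in-scope `sqlContext`, and the `org.apache.spark.sql.streaming.test` format, the `/test` path, and the checkpoint directory are stand-ins from the test harness, not real data sources:

```scala
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import org.apache.spark.sql.ProcessingTime

val df = sqlContext.read
  .format("org.apache.spark.sql.streaming.test")
  .stream("/test")

// The three construction styles all normalize to an interval in
// milliseconds, so these are all equal to ProcessingTime(10000).
assert(ProcessingTime("10 seconds") == ProcessingTime(10.seconds))
assert(ProcessingTime(10.seconds) == ProcessingTime.create(10, TimeUnit.SECONDS))

// Each batch now starts on a 10-second boundary instead of after a fixed
// 10 ms sleep; omit trigger() to keep the as-fast-as-possible default.
val query = df.write
  .format("org.apache.spark.sql.streaming.test")
  .option("checkpointLocation", "/tmp/checkpoint") // any writable directory
  .trigger(ProcessingTime("10 seconds"))
  .startStream()
```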