author     Shixiong Zhu <shixiong@databricks.com>    2016-04-04 10:54:06 -0700
committer  Michael Armbrust <michael@databricks.com> 2016-04-04 10:54:06 -0700
commit     855ed44ed31210d2001d7ce67c8fa99f8416edd3 (patch)
tree       fa495536694e76b3f4c9efa42e51eb24cee0476f
parent     89f3befab6c150f87de2fb91b50ea8b414c69095 (diff)
download   spark-855ed44ed31210d2001d7ce67c8fa99f8416edd3.tar.gz
           spark-855ed44ed31210d2001d7ce67c8fa99f8416edd3.tar.bz2
           spark-855ed44ed31210d2001d7ce67c8fa99f8416edd3.zip
[SPARK-14176][SQL] Add DataFrameWriter.trigger to set the stream batch period
## What changes were proposed in this pull request?

Add a processing time trigger to control the batch processing speed.

## How was this patch tested?

Unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #11976 from zsxwing/trigger.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala                          |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala                                 |  34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala                                         | 133
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala             |  24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala             |  72
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala                             |  40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala                                      |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala |  78
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala            |  28
9 files changed, 413 insertions(+), 13 deletions(-)
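
As a usage sketch (not part of the diff below), the new trigger composes with the existing streaming writer as follows; `df` is assumed to be a streaming DataFrame, and the sink format and checkpoint path are placeholders:

    import scala.concurrent.duration._
    import org.apache.spark.sql.ProcessingTime

    // Start a continuous query whose batches begin at most once every 10 seconds,
    // instead of the pre-patch behavior of sleeping a fixed 10 ms between batches.
    val query = df.write
      .format("your.streaming.sink")                   // placeholder sink provider
      .option("checkpointLocation", "/tmp/checkpoint") // placeholder path
      .trigger(ProcessingTime(10.seconds))             // or ProcessingTime("10 seconds")
      .startStream()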
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 465feeb604..2306df09b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -171,13 +171,20 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
name: String,
checkpointLocation: String,
df: DataFrame,
- sink: Sink): ContinuousQuery = {
+ sink: Sink,
+ trigger: Trigger = ProcessingTime(0)): ContinuousQuery = {
activeQueriesLock.synchronized {
if (activeQueries.contains(name)) {
throw new IllegalArgumentException(
s"Cannot start query with name $name as a query with that name is already active")
}
- val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink)
+ val query = new StreamExecution(
+ sqlContext,
+ name,
+ checkpointLocation,
+ df.logicalPlan,
+ sink,
+ trigger)
query.start()
activeQueries.put(name, query)
query
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index c07bd0e7b7..3332a997cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -78,6 +78,35 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
/**
+ * :: Experimental ::
+ * Set the trigger for the stream query. The default value is `ProcessingTime(0)`, which will run
+ * the query as fast as possible.
+ *
+ * Scala Example:
+ * {{{
+ * df.write.trigger(ProcessingTime("10 seconds"))
+ *
+ * import scala.concurrent.duration._
+ * df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ *
+ * Java Example:
+ * {{{
+ * df.write.trigger(ProcessingTime.create("10 seconds"))
+ *
+ * import java.util.concurrent.TimeUnit
+ * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ *
+ * @since 2.0.0
+ */
+ @Experimental
+ def trigger(trigger: Trigger): DataFrameWriter = {
+ this.trigger = trigger
+ this
+ }
+
+ /**
* Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
*
* @since 1.4.0
@@ -261,7 +290,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
queryName,
checkpointLocation,
df,
- dataSource.createSink())
+ dataSource.createSink(),
+ trigger)
}
/**
@@ -552,6 +582,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
private var mode: SaveMode = SaveMode.ErrorIfExists
+ private var trigger: Trigger = ProcessingTime(0L)
+
private var extraOptions = new scala.collection.mutable.HashMap[String, String]
private var partitioningColumns: Option[Seq[String]] = None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
new file mode 100644
index 0000000000..c4e54b3f90
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration.Duration
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * :: Experimental ::
+ * Used to indicate how often results should be produced by a [[ContinuousQuery]].
+ */
+@Experimental
+sealed trait Trigger {}
+
+/**
+ * :: Experimental ::
+ * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0,
+ * the query will run as fast as possible.
+ *
+ * Scala Example:
+ * {{{
+ * df.write.trigger(ProcessingTime("10 seconds"))
+ *
+ * import scala.concurrent.duration._
+ * df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ *
+ * Java Example:
+ * {{{
+ * df.write.trigger(ProcessingTime.create("10 seconds"))
+ *
+ * import java.util.concurrent.TimeUnit
+ * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ */
+@Experimental
+case class ProcessingTime(intervalMs: Long) extends Trigger {
+ require(intervalMs >= 0, "the interval of trigger should not be negative")
+}
+
+/**
+ * :: Experimental ::
+ * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s.
+ */
+@Experimental
+object ProcessingTime {
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * df.write.trigger(ProcessingTime("10 seconds"))
+ * }}}
+ */
+ def apply(interval: String): ProcessingTime = {
+ if (StringUtils.isBlank(interval)) {
+ throw new IllegalArgumentException(
+ "interval cannot be null or blank.")
+ }
+ val cal = if (interval.startsWith("interval")) {
+ CalendarInterval.fromString(interval)
+ } else {
+ CalendarInterval.fromString("interval " + interval)
+ }
+ if (cal == null) {
+ throw new IllegalArgumentException(s"Invalid interval: $interval")
+ }
+ if (cal.months > 0) {
+ throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
+ }
+ new ProcessingTime(cal.microseconds / 1000)
+ }
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * import scala.concurrent.duration._
+ * df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ */
+ def apply(interval: Duration): ProcessingTime = {
+ new ProcessingTime(interval.toMillis)
+ }
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * df.write.trigger(ProcessingTime.create("10 seconds"))
+ * }}}
+ */
+ def create(interval: String): ProcessingTime = {
+ apply(interval)
+ }
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * import java.util.concurrent.TimeUnit
+ * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ */
+ def create(interval: Long, unit: TimeUnit): ProcessingTime = {
+ new ProcessingTime(unit.toMillis(interval))
+ }
+}
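
For reference, the factory methods above all normalize to a millisecond interval, so the following constructions (a minimal sketch using only the classes added in this file) are all equivalent to `ProcessingTime(10000)`:

    import java.util.concurrent.TimeUnit
    import scala.concurrent.duration._
    import org.apache.spark.sql.ProcessingTime

    ProcessingTime("10 seconds")                // parsed through CalendarInterval
    ProcessingTime("interval 10 seconds")       // the explicit "interval" prefix is also accepted
    ProcessingTime(10.seconds)                  // scala.concurrent.duration overload
    ProcessingTime.create("10 seconds")         // Java-friendly string factory
    ProcessingTime.create(10, TimeUnit.SECONDS) // Java-friendly unit factory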
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 511e30c70c..64f80699ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -47,16 +47,14 @@ class StreamExecution(
override val name: String,
val checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
- val sink: Sink) extends ContinuousQuery with Logging {
+ val sink: Sink,
+ val trigger: Trigger) extends ContinuousQuery with Logging {
/** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
private val startLatch = new CountDownLatch(1)
private val terminationLatch = new CountDownLatch(1)
- /** Minimum amount of time in between the start of each batch. */
- private val minBatchTime = 10
-
/**
* Tracks how much data we have processed and committed to the sink or state store from each
* input source.
@@ -79,6 +77,10 @@ class StreamExecution(
/** A list of unique sources in the query plan. */
private val uniqueSources = sources.distinct
+ private val triggerExecutor = trigger match {
+ case t: ProcessingTime => ProcessingTimeExecutor(t)
+ }
+
/** Defines the internal state of execution */
@volatile
private var state: State = INITIALIZED
@@ -154,11 +156,15 @@ class StreamExecution(
SQLContext.setActive(sqlContext)
populateStartOffsets()
logDebug(s"Stream running from $committedOffsets to $availableOffsets")
- while (isActive) {
- if (dataAvailable) runBatch()
- commitAndConstructNextBatch()
- Thread.sleep(minBatchTime) // TODO: Could be tighter
- }
+ triggerExecutor.execute(() => {
+ if (isActive) {
+ if (dataAvailable) runBatch()
+ commitAndConstructNextBatch()
+ true
+ } else {
+ false
+ }
+ })
} catch {
case _: InterruptedException if state == TERMINATED => // interrupted by stop()
case NonFatal(e) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
new file mode 100644
index 0000000000..a1132d5106
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.{Clock, SystemClock}
+
+trait TriggerExecutor {
+
+ /**
+ * Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the execution.
+ */
+ def execute(batchRunner: () => Boolean): Unit
+}
+
+/**
+ * A trigger executor that runs a batch every `intervalMs` milliseconds.
+ */
+case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
+ extends TriggerExecutor with Logging {
+
+ private val intervalMs = processingTime.intervalMs
+
+ override def execute(batchRunner: () => Boolean): Unit = {
+ while (true) {
+ val batchStartTimeMs = clock.getTimeMillis()
+ val terminated = !batchRunner()
+ if (intervalMs > 0) {
+ val batchEndTimeMs = clock.getTimeMillis()
+ val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
+ if (batchElapsedTimeMs > intervalMs) {
+ notifyBatchFallingBehind(batchElapsedTimeMs)
+ }
+ if (terminated) {
+ return
+ }
+ clock.waitTillTime(nextBatchTime(batchEndTimeMs))
+ } else {
+ if (terminated) {
+ return
+ }
+ }
+ }
+ }
+
+ /** Called when a batch falls behind. Exposed for tests only */
+ def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+ logWarning("Current batch is falling behind. The trigger interval is " +
+ s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
+ }
+
+ /** Return the next multiple of intervalMs */
+ def nextBatchTime(now: Long): Long = {
+ (now - 1) / intervalMs * intervalMs + intervalMs
+ }
+}
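
The expression `(now - 1) / intervalMs * intervalMs + intervalMs` in `nextBatchTime` rounds the batch end time up to the next multiple of the interval while leaving exact multiples unchanged, so a batch that finishes exactly on a boundary starts the next one immediately rather than waiting a full extra interval. A worked sketch with `intervalMs = 100` (the same values asserted in ProcessingTimeExecutorSuite below):

    // nextBatchTime(1)   == (0   / 100) * 100 + 100 == 100  -> wait until t = 100
    // nextBatchTime(100) == (99  / 100) * 100 + 100 == 100  -> an exact boundary is kept as-is
    // nextBatchTime(101) == (100 / 100) * 100 + 100 == 200  -> an overrun skips to the next boundary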
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
new file mode 100644
index 0000000000..0d18a645f6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkFunSuite
+
+class ProcessingTimeSuite extends SparkFunSuite {
+
+ test("create") {
+ assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
+ assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
+ assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
+ assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)
+
+ intercept[IllegalArgumentException] { ProcessingTime(null: String) }
+ intercept[IllegalArgumentException] { ProcessingTime("") }
+ intercept[IllegalArgumentException] { ProcessingTime("invalid") }
+ intercept[IllegalArgumentException] { ProcessingTime("1 month") }
+ intercept[IllegalArgumentException] { ProcessingTime("1 year") }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 550c3c6f9c..3444e56e9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -288,7 +288,11 @@ trait StreamTest extends QueryTest with Timeouts {
currentStream =
sqlContext
.streams
- .startQuery(StreamExecution.nextName, metadataRoot, stream, sink)
+ .startQuery(
+ StreamExecution.nextName,
+ metadataRoot,
+ stream,
+ sink)
.asInstanceOf[StreamExecution]
currentStream.microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
new file mode 100644
index 0000000000..dd5f92248b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.ManualClock
+
+class ProcessingTimeExecutorSuite extends SparkFunSuite {
+
+ test("nextBatchTime") {
+ val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100))
+ assert(processingTimeExecutor.nextBatchTime(1) === 100)
+ assert(processingTimeExecutor.nextBatchTime(99) === 100)
+ assert(processingTimeExecutor.nextBatchTime(100) === 100)
+ assert(processingTimeExecutor.nextBatchTime(101) === 200)
+ assert(processingTimeExecutor.nextBatchTime(150) === 200)
+ }
+
+ private def testBatchTermination(intervalMs: Long): Unit = {
+ var batchCounts = 0
+ val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs))
+ processingTimeExecutor.execute(() => {
+ batchCounts += 1
+ // If the batch termination works well, batchCounts should be 3 after `execute`
+ batchCounts < 3
+ })
+ assert(batchCounts === 3)
+ }
+
+ test("batch termination") {
+ testBatchTermination(0)
+ testBatchTermination(10)
+ }
+
+ test("notifyBatchFallingBehind") {
+ val clock = new ManualClock()
+ @volatile var batchFallingBehindCalled = false
+ val latch = new CountDownLatch(1)
+ val t = new Thread() {
+ override def run(): Unit = {
+ val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) {
+ override def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+ batchFallingBehindCalled = true
+ }
+ }
+ processingTimeExecutor.execute(() => {
+ latch.countDown()
+ clock.waitTillTime(200)
+ false
+ })
+ }
+ }
+ t.start()
+ // Wait until the batch is running so that we don't call `advance` too early
+ assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds")
+ clock.advance(200)
+ t.join()
+ assert(batchFallingBehindCalled === true)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 102473d7d0..28c558208f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql.streaming.test
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
@@ -275,4 +279,28 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(activeStreamNames.contains("name"))
sqlContext.streams.active.foreach(_.stop())
}
+
+ test("trigger") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream("/test")
+
+ var q = df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("checkpointLocation", newMetadataDir)
+ .trigger(ProcessingTime(10.seconds))
+ .startStream()
+ q.stop()
+
+ assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000))
+
+ q = df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("checkpointLocation", newMetadataDir)
+ .trigger(ProcessingTime.create(100, TimeUnit.SECONDS))
+ .startStream()
+ q.stop()
+
+ assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
+ }
}