path: root/sql
author     Shixiong Zhu <shixiong@databricks.com>     2016-04-11 10:42:51 -0700
committer  Michael Armbrust <michael@databricks.com>  2016-04-11 10:42:51 -0700
commit     2dacc81ec31233e558855a26340ad4662d470387 (patch)
tree       d755aeaae1c3349ac8784c713d9f96cf016b05ad /sql
parent     5de26194a3aaeab9b7a8323107f614126c90441f (diff)
download   spark-2dacc81ec31233e558855a26340ad4662d470387.tar.gz
           spark-2dacc81ec31233e558855a26340ad4662d470387.tar.bz2
           spark-2dacc81ec31233e558855a26340ad4662d470387.zip
[SPARK-14494][SQL] Fix the race conditions in MemoryStream and MemorySink
## What changes were proposed in this pull request?

Make sure accesses to the mutable variables in MemoryStream and MemorySink are protected by `synchronized`. This is probably why MemorySinkSuite failed here:
https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/

## How was this patch tested?

Existing unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12261 from zsxwing/memory-race-condition.
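The fix follows a standard lock-the-state pattern: each mutable field is annotated with `@GuardedBy("this")` and every read or write of those fields happens inside `this.synchronized`. Below is a minimal, self-contained sketch of that pattern; the class and member names are illustrative, not Spark's actual types, and `@GuardedBy` is the JSR-305 annotation imported by the patch (it documents which lock guards a field but is not enforced at runtime).

```scala
// Illustrative sketch only: a toy class mirroring the locking discipline the
// patch applies to MemoryStream. Names here are hypothetical, not Spark's API.
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer

class GuardedLog[A] {
  // Both fields are only touched while holding `this`.
  @GuardedBy("this") private val batches = new ArrayBuffer[A]
  @GuardedBy("this") private var currentOffset: Long = -1L

  /** Append an item and return the offset it was written at. */
  def add(item: A): Long = this.synchronized {
    currentOffset += 1
    batches.append(item)
    currentOffset
  }

  /** The emptiness check and the offset read must be one atomic step. */
  def latestOffset: Option[Long] = this.synchronized {
    if (batches.isEmpty) None else Some(currentOffset)
  }

  /** Take a consistent snapshot of a range while holding the lock. */
  def range(start: Int, end: Int): Seq[A] = this.synchronized {
    batches.slice(start, end).toList
  }
}
```

In the diff below, the same discipline is applied to `batches` and `currentOffset` in MemoryStream, and to the `batches` buffer in MemorySink.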
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala  |  25
1 file changed, 16 insertions(+), 9 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 351ef404a8..3820968324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
protected val encoder = encoderFor[A]
protected val logicalPlan = StreamingExecutionRelation(this)
protected val output = logicalPlan.output
+
+ @GuardedBy("this")
protected val batches = new ArrayBuffer[Dataset[A]]
+ @GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
def schema: StructType = encoder.schema
@@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def addData(data: TraversableOnce[A]): Offset = {
import sqlContext.implicits._
+ val ds = data.toVector.toDS()
+ logDebug(s"Adding ds: $ds")
this.synchronized {
currentOffset = currentOffset + 1
- val ds = data.toVector.toDS()
- logDebug(s"Adding ds: $ds")
batches.append(ds)
currentOffset
}
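A detail worth noting in the hunk above: the Dataset construction (`data.toVector.toDS()`) and its debug log are hoisted out of the `synchronized` block, so the lock is held only while the guarded fields are mutated. A hedged sketch of that shape, with illustrative names rather than Spark's API:

```scala
// Sketch only: pure work that touches no shared state is done before taking
// the lock, keeping the critical section as short as possible.
import scala.collection.mutable.ArrayBuffer

class ToyStream {
  private val batches = new ArrayBuffer[Vector[String]]
  private var currentOffset: Long = -1L

  def addData(data: Seq[String]): Long = {
    val prepared = data.toVector       // no shared state touched; no lock held
    this.synchronized {                // lock only around the guarded fields
      currentOffset += 1
      batches.append(prepared)
      currentOffset
    }
  }
}
```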
@@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
override def toString: String = s"MemoryStream[${output.mkString(",")}]"
- override def getOffset: Option[Offset] = if (batches.isEmpty) {
- None
- } else {
- Some(currentOffset)
+ override def getOffset: Option[Offset] = synchronized {
+ if (batches.isEmpty) {
+ None
+ } else {
+ Some(currentOffset)
+ }
}
/**
@@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
val startOrdinal =
start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
- val newBlocks = batches.slice(startOrdinal, endOrdinal)
+ val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
logDebug(
s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
@@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
*/
class MemorySink(val schema: StructType) extends Sink with Logging {
/** An order list of batches that have been written to this [[Sink]]. */
+ @GuardedBy("this")
private val batches = new ArrayBuffer[Array[Row]]()
/** Returns all rows that are stored in this [[Sink]]. */
@@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
batches.flatten
}
- def lastBatch: Seq[Row] = batches.last
+ def lastBatch: Seq[Row] = synchronized { batches.last }
def toDebugString: String = synchronized {
batches.zipWithIndex.map { case (b, i) =>
@@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
}.mkString("\n")
}
- override def addBatch(batchId: Long, data: DataFrame): Unit = {
+ override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
if (batchId == batches.size) {
logDebug(s"Committing batch $batchId")
batches.append(data.collect())
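On the sink side, making `addBatch` a single `synchronized` method means the `batchId == batches.size` check and the append happen atomically with respect to concurrent readers such as `allData` and `lastBatch`. A toy-sized sketch of that check-then-append pattern follows; the class name and the handling of ids other than the expected one are illustrative assumptions, not the exact Spark code (which is truncated above).

```scala
// Hedged sketch of the sink-side pattern: the whole check-then-append is one
// atomic step under `synchronized`. Names and error handling are illustrative.
import scala.collection.mutable.ArrayBuffer

class ToySink {
  private val batches = new ArrayBuffer[Array[String]]()

  def addBatch(batchId: Long, rows: Array[String]): Unit = synchronized {
    if (batchId == batches.size) {
      batches.append(rows)                 // first delivery: commit it
    } else if (batchId < batches.size) {
      // assumed behaviour: a replayed batch is already committed, so skip it
    } else {
      sys.error(s"unexpected batch id $batchId, expected ${batches.size}")
    }
  }

  def allRows: Seq[String] = synchronized { batches.flatten.toList }
}
```

With the whole method under one lock, a concurrent reader can never observe a half-appended batch, which is the failure mode the MemorySinkSuite flakiness pointed at.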