diff options
author | Shixiong Zhu <shixiong@databricks.com> | 2016-04-11 10:42:51 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2016-04-11 10:42:51 -0700 |
commit | 2dacc81ec31233e558855a26340ad4662d470387 (patch) | |
tree | d755aeaae1c3349ac8784c713d9f96cf016b05ad /sql | |
parent | 5de26194a3aaeab9b7a8323107f614126c90441f (diff) | |
download | spark-2dacc81ec31233e558855a26340ad4662d470387.tar.gz spark-2dacc81ec31233e558855a26340ad4662d470387.tar.bz2 spark-2dacc81ec31233e558855a26340ad4662d470387.zip |
[SPARK-14494][SQL] Fix the race conditions in MemoryStream and MemorySink
## What changes were proposed in this pull request?
Make sure accesses to mutable variables in MemoryStream and MemorySink are protected by `synchronized`.
This is probably why MemorySinkSuite failed here: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/
## How was this patch tested?
Existing unit tests.
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #12261 from zsxwing/memory-race-condition.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 351ef404a8..3820968324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val encoder = encoderFor[A] protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output + + @GuardedBy("this") protected val batches = new ArrayBuffer[Dataset[A]] + @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) def schema: StructType = encoder.schema @@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = { import sqlContext.implicits._ + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") this.synchronized { currentOffset = currentOffset + 1 - val ds = data.toVector.toDS() - logDebug(s"Adding ds: $ds") batches.append(ds) currentOffset } @@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${output.mkString(",")}]" - override def getOffset: Option[Offset] = if (batches.isEmpty) { - None - } else { - Some(currentOffset) + override def getOffset: Option[Offset] = synchronized { + if (batches.isEmpty) { + None + } else { + Some(currentOffset) + } } /** @@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val startOrdinal = 
start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 - val newBlocks = batches.slice(startOrdinal, endOrdinal) + val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } logDebug( s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") @@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) */ class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ + @GuardedBy("this") private val batches = new ArrayBuffer[Array[Row]]() /** Returns all rows that are stored in this [[Sink]]. */ @@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging { batches.flatten } - def lastBatch: Seq[Row] = batches.last + def lastBatch: Seq[Row] = synchronized { batches.last } def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => @@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging { }.mkString("\n") } - override def addBatch(batchId: Long, data: DataFrame): Unit = { + override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { if (batchId == batches.size) { logDebug(s"Committing batch $batchId") batches.append(data.collect()) |