From 2dacc81ec31233e558855a26340ad4662d470387 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Mon, 11 Apr 2016 10:42:51 -0700
Subject: [SPARK-14494][SQL] Fix the race conditions in MemoryStream and
 MemorySink

## What changes were proposed in this pull request?

Make sure accesses to mutable variables in MemoryStream and MemorySink are protected by `synchronized`.

This is probably why MemorySinkSuite failed here: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/

## How was this patch tested?

Existing unit tests.

Author: Shixiong Zhu

Closes #12261 from zsxwing/memory-race-condition.
---
 .../spark/sql/execution/streaming/memory.scala | 25 ++++++++++++++--------
 1 file changed, 16 insertions(+), 9 deletions(-)

(limited to 'sql')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 351ef404a8..3820968324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
@@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   protected val encoder = encoderFor[A]
   protected val logicalPlan = StreamingExecutionRelation(this)
   protected val output = logicalPlan.output
+
+  @GuardedBy("this")
   protected val batches = new ArrayBuffer[Dataset[A]]
 
+  @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
   def schema: StructType = encoder.schema
@@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def addData(data: TraversableOnce[A]): Offset = {
     import sqlContext.implicits._
+    val ds = data.toVector.toDS()
+    logDebug(s"Adding ds: $ds")
     this.synchronized {
       currentOffset = currentOffset + 1
-      val ds = data.toVector.toDS()
-      logDebug(s"Adding ds: $ds")
       batches.append(ds)
       currentOffset
     }
@@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   override def toString: String = s"MemoryStream[${output.mkString(",")}]"
 
-  override def getOffset: Option[Offset] = if (batches.isEmpty) {
-    None
-  } else {
-    Some(currentOffset)
+  override def getOffset: Option[Offset] = synchronized {
+    if (batches.isEmpty) {
+      None
+    } else {
+      Some(currentOffset)
+    }
   }
 
   /**
@@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     val startOrdinal =
       start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
     val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
-    val newBlocks = batches.slice(startOrdinal, endOrdinal)
+    val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
 
     logDebug(
       s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
@@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  */
 class MemorySink(val schema: StructType) extends Sink with Logging {
   /** An order list of batches that have been written to this [[Sink]]. */
+  @GuardedBy("this")
   private val batches = new ArrayBuffer[Array[Row]]()
 
   /** Returns all rows that are stored in this [[Sink]]. */
@@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     batches.flatten
   }
 
-  def lastBatch: Seq[Row] = batches.last
+  def lastBatch: Seq[Row] = synchronized { batches.last }
 
   def toDebugString: String = synchronized {
     batches.zipWithIndex.map { case (b, i) =>
@@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     }.mkString("\n")
   }
 
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
     if (batchId == batches.size) {
       logDebug(s"Committing batch $batchId")
      batches.append(data.collect())
-- 
cgit v1.2.3
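For context on the pattern this patch applies, below is a small, self-contained sketch in plain Scala (no Spark dependencies): every read and write of a mutable buffer goes through `this.synchronized`, so a reader thread can never observe the buffer in the middle of an append. `SimpleMemorySink` and the demo object are names invented for this illustration and are not part of Spark; the real `MemorySink` additionally stores collected `DataFrame` rows and documents the locking invariant with `@GuardedBy("this")`.

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative only: a tiny in-memory sink that guards its mutable state the
// same way the patch does, by routing every access through this.synchronized.
class SimpleMemorySink[T] {
  // All reads and writes of `batches` must hold the monitor of `this`
  // (the patch documents the same invariant with @GuardedBy("this")).
  private val batches = new ArrayBuffer[Seq[T]]()

  // Writer side: the size check and the append happen under a single lock
  // acquisition, so two writers cannot both commit the same batch id.
  def addBatch(batchId: Long, data: Seq[T]): Unit = synchronized {
    if (batchId == batches.size) {
      batches.append(data)
    }
  }

  // Reader side: flattening happens under the lock, so a concurrent append
  // cannot be observed half-way through.
  def allData: Seq[T] = synchronized { batches.flatten.toSeq }

  def lastBatch: Seq[T] = synchronized { batches.last }
}

object SimpleMemorySinkDemo {
  def main(args: Array[String]): Unit = {
    val sink = new SimpleMemorySink[Int]
    // One thread commits batches while another repeatedly reads; without the
    // synchronized blocks the reader could see the ArrayBuffer mid-update.
    val writer = new Thread(new Runnable {
      def run(): Unit = (0 until 100).foreach(i => sink.addBatch(i, Seq(i)))
    })
    val reader = new Thread(new Runnable {
      def run(): Unit = (0 until 100).foreach(_ => sink.allData)
    })
    writer.start(); reader.start()
    writer.join(); reader.join()
    println(s"rows written: ${sink.allData.size}")
  }
}
```

The design point mirrored by the patch is that check-then-act sequences (such as comparing `batchId` against the buffer size before appending) must sit inside one `synchronized` block, not merely have each individual access locked.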