author     Shixiong Zhu <shixiong@databricks.com>    2016-09-14 13:33:51 -0700
committer  Josh Rosen <joshrosen@databricks.com>     2016-09-14 13:33:51 -0700
commit     e33bfaed3b160fbc617c878067af17477a0044f5 (patch)
tree       369215b3c286b9d9f2694ebec7f75bbcf7fdfba4
parent     ff6e4cbdc80e2ad84c5d70ee07f323fad9374e3e (diff)
[SPARK-17463][CORE] Make the values of CollectionAccumulator and SetAccumulator readable thread-safely

## What changes were proposed in this pull request?

Make the values of CollectionAccumulator and SetAccumulator readable thread-safely, to fix the ConcurrentModificationException reported in [JIRA](https://issues.apache.org/jira/browse/SPARK-17463).

## How was this patch tested?

Existing tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #15063 from zsxwing/SPARK-17463.
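The failure mode behind SPARK-17463 is a plain java.util.ArrayList (or scala.collection.mutable collection) being iterated by one thread, for example while serializing accumulator updates, while a task thread is still appending to it, which throws ConcurrentModificationException. The fix wraps the backing collections in Collections.synchronizedList / Collections.synchronizedSet. A minimal, self-contained sketch of that pattern (not Spark code; class and value names are illustrative):

    import java.util.{ArrayList, Collections}

    object SynchronizedListSketch {
      def main(args: Array[String]): Unit = {
        // Thread-safe backing list, the same wrapper the patched accumulators use.
        val list = Collections.synchronizedList(new ArrayList[Int]())

        val writer = new Thread(new Runnable {
          override def run(): Unit = (1 to 100000).foreach(i => list.add(i))
        })
        writer.start()

        // Individual add/get/size calls are synchronized by the wrapper; for a
        // consistent snapshot while the writer may still be running, copy the
        // list while holding its own lock.
        val snapshot = list.synchronized { new ArrayList[Int](list) }

        writer.join()
        println(s"snapshot: ${snapshot.size} elements, final: ${list.size} elements")
      }
    }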
core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala            | 41
core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala              |  7
core/src/main/scala/org/apache/spark/util/JsonProtocol.scala               | 11
core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala          |  3
sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 24
5 files changed, 54 insertions(+), 32 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index dd149a919f..52a349919e 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,9 @@
package org.apache.spark.executor
+import java.util.{ArrayList, Collections}
+
+import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
import org.apache.spark._
@@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable {
/**
* Storage statuses of any blocks that have been updated as a result of this task.
*/
- def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.value
+ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = {
+ // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
+ // `asScala` which accesses the internal values using `java.util.Iterator`.
+ _updatedBlockStatuses.value.asScala
+ }
// Setters and increment-ers
private[spark] def setExecutorDeserializeTime(v: Long): Unit =
@@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
_updatedBlockStatuses.add(v)
- private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+ private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit =
_updatedBlockStatuses.setValue(v)
+ private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+ _updatedBlockStatuses.setValue(v.asJava)
/**
* Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
@@ -268,7 +277,7 @@ private[spark] object TaskMetrics extends Logging {
val name = info.name.get
val value = info.update.get
if (name == UPDATED_BLOCK_STATUSES) {
- tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]])
+ tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]])
} else {
tm.nameToAccums.get(name).foreach(
_.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
@@ -299,8 +308,8 @@ private[spark] object TaskMetrics extends Logging {
private[spark] class BlockStatusesAccumulator
- extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
- private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
+ extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] {
+ private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]())
override def isZero(): Boolean = _seq.isEmpty
@@ -308,25 +317,27 @@ private[spark] class BlockStatusesAccumulator
override def copy(): BlockStatusesAccumulator = {
val newAcc = new BlockStatusesAccumulator
- newAcc._seq = _seq.clone()
+ newAcc._seq.addAll(_seq)
newAcc
}
override def reset(): Unit = _seq.clear()
- override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
+ override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v)
- override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
- : Unit = other match {
- case o: BlockStatusesAccumulator => _seq ++= o.value
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ override def merge(
+ other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = {
+ other match {
+ case o: BlockStatusesAccumulator => _seq.addAll(o.value)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
}
- override def value: Seq[(BlockId, BlockStatus)] = _seq
+ override def value: java.util.List[(BlockId, BlockStatus)] = _seq
- def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
+ def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = {
_seq.clear()
- _seq ++= newValue
+ _seq.addAll(newValue)
}
}
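BlockStatusesAccumulator now exposes java.util.List instead of Seq, and copy(), merge(), and setValue all go through addAll on the synchronized list rather than cloning a Scala buffer. The comment added to updatedBlockStatuses captures why a bare asScala read is enough on the driver: once the task has finished, nothing is appending anymore. A condensed sketch of that read pattern (illustrative values, not the real Spark types):

    import java.util.{ArrayList, Collections}
    import scala.collection.JavaConverters._

    // Stand-in for a finished accumulator's value: a synchronized java.util.List.
    val updates: java.util.List[(String, Long)] =
      Collections.synchronizedList(new ArrayList[(String, Long)]())
    updates.add(("rdd_0_0", 1024L))

    // asScala is a wrapper, not a copy; its iterator walks the underlying list.
    // That is safe here only because no other thread is still mutating `updates`.
    val asSeq: Seq[(String, Long)] = updates.asScala
    asSeq.foreach { case (blockId, size) => println(s"$blockId -> $size bytes") }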
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index d130a37db5..470d912ecf 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util
import java.{lang => jl}
import java.io.ObjectInputStream
-import java.util.ArrayList
+import java.util.{ArrayList, Collections}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
@@ -38,6 +38,9 @@ private[spark] case class AccumulatorMetadata(
/**
* The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
* type `OUT`.
+ *
+ * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely
+ * (e.g., synchronized collections) because it will be read from other threads.
*/
abstract class AccumulatorV2[IN, OUT] extends Serializable {
private[spark] var metadata: AccumulatorMetadata = _
@@ -433,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
* @since 2.0.0
*/
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
- private val _list: java.util.List[T] = new ArrayList[T]
+ private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())
override def isZero: Boolean = _list.isEmpty
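CollectionAccumulator is a public API (created via SparkContext.collectionAccumulator), so this is the part of the change user code is most likely to touch: the value is still a java.util.List[T], it is just backed by a synchronized list now. A hedged usage sketch, assuming a local master and illustrative names:

    import scala.collection.JavaConverters._
    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("acc-demo"))
    val badRecords = sc.collectionAccumulator[String]("badRecords")

    sc.parallelize(Seq("1", "x", "2", "y"), 2).foreach { s =>
      if (!s.forall(_.isDigit)) badRecords.add(s)  // adds happen on task threads
    }

    // Driver side, after the job has completed: asScala wraps the Java list
    // without copying, so this read no longer races with task-side adds.
    println(badRecords.value.asScala.sorted.mkString(", "))
    sc.stop()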
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 022b226894..41d947c442 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -310,11 +310,12 @@ private[spark] object JsonProtocol {
case v: Int => JInt(v)
case v: Long => JInt(v)
// We only have 3 kind of internal accumulator types, so if it's not int or long, it must be
- // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]`
+ // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]`
case v =>
- JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) =>
- ("Block ID" -> id.toString) ~
- ("Status" -> blockStatusToJson(status))
+ JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map {
+ case (id, status) =>
+ ("Block ID" -> id.toString) ~
+ ("Status" -> blockStatusToJson(status))
})
}
} else {
@@ -743,7 +744,7 @@ private[spark] object JsonProtocol {
val id = BlockId((blockJson \ "Block ID").extract[String])
val status = blockStatusFromJson(blockJson \ "Status")
(id, status)
- }
+ }.asJava
case _ => throw new IllegalArgumentException(s"unexpected json value $value for " +
"accumulator " + name.get)
}
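On the JSON side the change is pure collection plumbing: asScala to map over the Java list when writing the accumulator value, asJava to hand the rebuilt pairs back as a java.util.List when reading. A standalone sketch of those two conversions (no json4s involved; strings stand in for the rendered JSON):

    import java.util.{List => JList}
    import scala.collection.JavaConverters._

    // Writing: view a java.util.List as a Scala collection and map over it.
    val toWrite: JList[(String, Long)] = Seq(("rdd_0_0", 512L), ("rdd_0_1", 256L)).asJava
    val rendered = toWrite.asScala.toList.map { case (id, size) =>
      s"""{"Block ID": "$id", "Size": $size}"""
    }

    // Reading: a rebuilt Scala Seq is converted back with asJava before being
    // handed to a setter that now expects java.util.List.
    val restored: JList[(String, Long)] = Seq(("rdd_0_0", 512L)).asJava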
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 85ca9d39d4..c89be22a34 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
import java.util.Properties
+import scala.collection.JavaConverters._
import scala.collection.Map
import org.json4s.jackson.JsonMethods._
@@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite {
})
testAccumValue(Some(RESULT_SIZE), 3L, JInt(3))
testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2))
- testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson)
+ testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson)
// For anything else, we just cast the value to a string
testAccumValue(Some("anything"), blocks, JString(blocks.toString))
testAccumValue(Some("anything"), 123, JString("123"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 082f97a880..d321f4cd76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution
-import scala.collection.mutable.HashSet
+import java.util.Collections
+
+import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -107,18 +109,20 @@ package object debug {
case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
def output: Seq[Attribute] = child.output
- class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] {
- private val _set = new HashSet[T]()
+ class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] {
+ private val _set = Collections.synchronizedSet(new java.util.HashSet[T]())
override def isZero: Boolean = _set.isEmpty
- override def copy(): AccumulatorV2[T, HashSet[T]] = {
+ override def copy(): AccumulatorV2[T, java.util.Set[T]] = {
val newAcc = new SetAccumulator[T]()
- newAcc._set ++= _set
+ newAcc._set.addAll(_set)
newAcc
}
override def reset(): Unit = _set.clear()
- override def add(v: T): Unit = _set += v
- override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value
- override def value: HashSet[T] = _set
+ override def add(v: T): Unit = _set.add(v)
+ override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = {
+ _set.addAll(other.value)
+ }
+ override def value: java.util.Set[T] = _set
}
/**
@@ -138,7 +142,9 @@ package object debug {
debugPrint(s"== ${child.simpleString} ==")
debugPrint(s"Tuples output: ${tupleCount.value}")
child.output.zip(columnStats).foreach { case (attr, metric) =>
- val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
+ // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
+ // `asScala` which accesses the internal values using `java.util.Iterator`.
+ val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}")
debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
}
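DebugExec's SetAccumulator follows the same recipe with a set: a java.util.HashSet wrapped in Collections.synchronizedSet, mutated through add/addAll, and read via asScala on the driver once the job is done. A generic sketch of that set pattern outside Spark (illustrative names):

    import java.util.Collections
    import scala.collection.JavaConverters._

    // Thread-safe set, the same wrapper the rewritten SetAccumulator uses internally.
    val seenTypes = Collections.synchronizedSet(new java.util.HashSet[String]())
    seenTypes.add("IntegerType")
    seenTypes.add("StringType")

    // Single calls are synchronized; iterating while writers may still be active
    // requires holding the set's own lock. On the driver after the job finishes
    // (the DebugExec case), a plain asScala read is enough.
    val rendered = seenTypes.synchronized { seenTypes.asScala.toSet }.mkString("{", ",", "}")
    println(rendered)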