diff options
5 files changed, 54 insertions, 32 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dd149a919f..52a349919e 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,9 @@ package org.apache.spark.executor +import java.util.{ArrayList, Collections} + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} import org.apache.spark._ @@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. */ - def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.value + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + _updatedBlockStatuses.value.asScala + } // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = @@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit = _updatedBlockStatuses.add(v) - private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v.asJava) /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted @@ -268,7 +277,7 @@ private[spark] object TaskMetrics extends Logging { val name = info.name.get val value = info.update.get if (name == UPDATED_BLOCK_STATUSES) { - tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]]) + tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]]) } else { tm.nameToAccums.get(name).foreach( _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long]) @@ -299,8 +308,8 @@ private[spark] object TaskMetrics extends Logging { private[spark] class BlockStatusesAccumulator - extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] { - private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)] + extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] { + private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]()) override def isZero(): Boolean = _seq.isEmpty @@ -308,25 +317,27 @@ private[spark] class BlockStatusesAccumulator override def copy(): BlockStatusesAccumulator = { val newAcc = new BlockStatusesAccumulator - newAcc._seq = _seq.clone() + newAcc._seq.addAll(_seq) newAcc } override def reset(): Unit = _seq.clear() - override def add(v: (BlockId, BlockStatus)): Unit = _seq += v + override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v) - override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]]) - : Unit = other match { - case o: BlockStatusesAccumulator => _seq ++= o.value - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + override def merge( + other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = { + other match { + case o: BlockStatusesAccumulator => _seq.addAll(o.value) + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } } - override def value: Seq[(BlockId, BlockStatus)] = _seq + override def value: java.util.List[(BlockId, BlockStatus)] = _seq - def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = { + def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = { _seq.clear() - _seq ++= newValue + _seq.addAll(newValue) } } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d130a37db5..470d912ecf 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.{lang => jl} import java.io.ObjectInputStream -import java.util.ArrayList +import java.util.{ArrayList, Collections} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong @@ -38,6 +38,9 @@ private[spark] case class AccumulatorMetadata( /** * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of * type `OUT`. + * + * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely + * (e.g., synchronized collections) because it will be read from other threads. */ abstract class AccumulatorV2[IN, OUT] extends Serializable { private[spark] var metadata: AccumulatorMetadata = _ @@ -433,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { * @since 2.0.0 */ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { - private val _list: java.util.List[T] = new ArrayList[T] + private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) override def isZero: Boolean = _list.isEmpty diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 022b226894..41d947c442 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -310,11 +310,12 @@ private[spark] object JsonProtocol { case v: Int => JInt(v) case v: Long => JInt(v) // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be - // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]` + // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]` case v => - JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) + JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map { + case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) }) } } else { @@ -743,7 +744,7 @@ private[spark] object JsonProtocol { val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") (id, status) - } + }.asJava case _ => throw new IllegalArgumentException(s"unexpected json value $value for " + "accumulator " + name.get) } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 85ca9d39d4..c89be22a34 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite { }) testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) - testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson) // For anything else, we just cast the value to a string testAccumValue(Some("anything"), blocks, JString(blocks.toString)) testAccumValue(Some("anything"), 123, JString("123")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 082f97a880..d321f4cd76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.HashSet +import java.util.Collections + +import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -107,18 +109,20 @@ package object debug { case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output - class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] { - private val _set = new HashSet[T]() + class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] { + private val _set = Collections.synchronizedSet(new java.util.HashSet[T]()) override def isZero: Boolean = _set.isEmpty - override def copy(): AccumulatorV2[T, HashSet[T]] = { + override def copy(): AccumulatorV2[T, java.util.Set[T]] = { val newAcc = new SetAccumulator[T]() - newAcc._set ++= _set + newAcc._set.addAll(_set) newAcc } override def reset(): Unit = _set.clear() - override def add(v: T): Unit = _set += v - override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value - override def value: HashSet[T] = _set + override def add(v: T): Unit = _set.add(v) + override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = { + _set.addAll(other.value) + } + override def value: java.util.Set[T] = _set } /** @@ -138,7 +142,9 @@ package object debug { debugPrint(s"== ${child.simpleString} ==") debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => - val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}") debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } |