author     Wenchen Fan <wenchen@databricks.com>    2016-04-28 00:26:39 -0700
committer  Reynold Xin <rxin@databricks.com>       2016-04-28 00:26:39 -0700
commit     bf5496dbdac75ea69081c95a92a29771e635ea98 (patch)
tree       b6151d946b25171b5bbcd3252aa77f7f7a69f60c /sql
parent     be317d4a90b3ca906fefeb438f89a09b1c7da5a8 (diff)
[SPARK-14654][CORE] New accumulator API
## What changes were proposed in this pull request?

This PR introduces a new accumulator API which is much simpler than before:

1. The type hierarchy is simplified; there is now only a single `Accumulator` class.
2. The `initialValue` and `zeroValue` concepts are combined into one concept: `zeroValue`.
3. There is only one `register` method; accumulator registration and cleanup registration are combined.
4. The `id`, `name` and `countFailedValues` are combined into an `AccumulatorMetadata`, which is provided during registration.

`SQLMetric` is a good example of the simplicity of this new API.

What we break:

1. No `setValue` anymore. In the new API the intermediate type can be different from the result type, so it is very hard to implement a general `setValue`.
2. An accumulator can't be serialized before it is registered.

Problems that need to be addressed in follow-ups:

1. With this new API, `AccumulatorInfo` doesn't make a lot of sense: the partial output is not partial updates, so we need to expose the intermediate value.
2. `ExceptionFailure` should not carry the accumulator updates. Why do users care about accumulator updates for failed cases? It looks like we only use this feature to update internal metrics, so how about sending a heartbeat to update internal metrics after the failure event?
3. The public event `SparkListenerTaskEnd` carries a `TaskMetrics`. Ideally this `TaskMetrics` doesn't need to carry external accumulators, as the only method of `TaskMetrics` that can access external accumulators is `private[spark]`. However, `SQLListener` uses it to retrieve SQL metrics.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12612 from cloud-fan/acc.
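For orientation, here is a minimal sketch of a custom metric written against the new API, modeled on the `SQLMetric` added in the diff below. It assumes `NewAccumulator` requires exactly the members that `SQLMetric` overrides in this patch; anything beyond those overrides is an assumption, and `SumMetric` is a hypothetical name.

```scala
import org.apache.spark.NewAccumulator

// A sum-only metric in the style of the new SQLMetric: a single class, a
// zero value baked into the accumulator itself, and one register() call
// that covers both accumulator registration and cleanup registration.
class SumMetric extends NewAccumulator[Long, Long] {
  private[this] var _sum = 0L

  override def copyAndReset(): SumMetric = new SumMetric

  override def isZero(): Boolean = _sum == 0L

  override def add(v: Long): Unit = _sum += v

  override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
    case o: SumMetric => _sum += o.localValue
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def localValue: Long = _sum
}

// Registration must happen before the accumulator is shipped to tasks; the
// id, name and countFailedValues now live in the AccumulatorMetadata that
// register() attaches:
//
//   val acc = new SumMetric
//   acc.register(sc, name = Some("my sum"), countFailedValues = false)
```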
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala | 218
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala | 6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala | 2
30 files changed, 154 insertions, 255 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 520ceaaaea..d6516f26a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -106,7 +106,7 @@ private[sql] case class RDDScanExec(
override val nodeName: String) extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
@@ -147,7 +147,7 @@ private[sql] case class RowDataSourceScanExec(
extends DataSourceScanExec with CodegenSupport {
private[sql] override lazy val metrics =
- Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -216,7 +216,7 @@ private[sql] case class BatchedDataSourceScanExec(
extends DataSourceScanExec with CodegenSupport {
private[sql] override lazy val metrics =
- Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 7c4756663a..c201822d44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -40,7 +40,7 @@ case class ExpandExec(
extends UnaryExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// The GroupExpressions can output data with arbitrary partitioning, so set it
// as UNKNOWN partitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 10cfec3330..934bc38dc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -56,7 +56,7 @@ case class GenerateExec(
extends UnaryExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = AttributeSet(output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 4ab447a47b..c5e78b0333 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -31,7 +31,7 @@ private[sql] case class LocalTableScanExec(
rows: Seq[InternalRow]) extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
private val unsafeRows: Array[InternalRow] = {
val proj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 861ff3cd15..0bbe970420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.ThreadUtils
@@ -77,7 +77,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return all metrics containing metrics of this SparkPlan.
*/
- private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
+ private[sql] def metrics: Map[String, SQLMetric] = Map.empty
/**
* Reset all the metrics.
@@ -89,8 +89,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return a LongSQLMetric according to the name.
*/
- private[sql] def longMetric(name: String): LongSQLMetric =
- metrics(name).asInstanceOf[LongSQLMetric]
+ private[sql] def longMetric(name: String): SQLMetric = metrics(name)
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index cb4b1cfeb9..f84070a0c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -55,8 +55,7 @@ private[sql] object SparkPlanInfo {
case _ => plan.children ++ plan.subqueries
}
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
- new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
- Utils.getFormattedClassName(metric.param))
+ new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
}
new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 362d0d7a72..484923428f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.unsafe.Platform
/**
@@ -42,7 +42,7 @@ import org.apache.spark.unsafe.Platform
*/
private[sql] class UnsafeRowSerializer(
numFields: Int,
- dataSize: LongSQLMetric = null) extends Serializer with Serializable {
+ dataSize: SQLMetric = null) extends Serializer with Serializable {
override def newInstance(): SerializerInstance =
new UnsafeRowSerializerInstance(numFields, dataSize)
override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
@@ -50,7 +50,7 @@ private[sql] class UnsafeRowSerializer(
private class UnsafeRowSerializerInstance(
numFields: Int,
- dataSize: LongSQLMetric) extends SerializerInstance {
+ dataSize: SQLMetric) extends SerializerInstance {
/**
* Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
* length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
@@ -60,13 +60,10 @@ private class UnsafeRowSerializerInstance(
private[this] val dOut: DataOutputStream =
new DataOutputStream(new BufferedOutputStream(out))
- // LongSQLMetricParam.add() is faster than LongSQLMetric.+=
- val localDataSize = if (dataSize != null) dataSize.localValue else null
-
override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
- if (localDataSize != null) {
- localDataSize.add(row.getSizeInBytes)
+ if (dataSize != null) {
+ dataSize.add(row.getSizeInBytes)
}
dOut.writeInt(row.getSizeInBytes)
row.writeToStream(dOut, writeBuffer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6a03bd08c5..15b4abe806 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -52,11 +52,7 @@ trait CodegenSupport extends SparkPlan {
* @return name of the variable representing the metric
*/
def metricTerm(ctx: CodegenContext, name: String): String = {
- val metric = ctx.addReferenceObj(name, longMetric(name))
- val value = ctx.freshName("metricValue")
- val cls = classOf[LongSQLMetricValue].getName
- ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
- value
+ ctx.addReferenceObj(name, longMetric(name))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
index 3169e0a2fd..2e74d59c5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
@@ -46,7 +46,7 @@ case class SortBasedAggregateExec(
AttributeSet(aggregateBufferAttributes)
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index c35d781d3e..f392b135ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
/**
* An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
@@ -35,7 +35,7 @@ class SortBasedAggregationIterator(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends AggregationIterator(
groupingExpressions,
valueAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 16362f756f..d0ba37ee13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
@@ -51,7 +51,7 @@ case class TungstenAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
@@ -309,8 +309,8 @@ case class TungstenAggregate(
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter,
- peakMemory: LongSQLMetricValue,
- spillSize: LongSQLMetricValue): KVIterator[UnsafeRow, UnsafeRow] = {
+ peakMemory: SQLMetric,
+ spillSize: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 9db5087fe0..243aa15deb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
@@ -86,9 +86,9 @@ class TungstenAggregationIterator(
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[(Int, Int)],
- numOutputRows: LongSQLMetric,
- peakMemory: LongSQLMetric,
- spillSize: LongSQLMetric)
+ numOutputRows: SQLMetric,
+ peakMemory: SQLMetric,
+ spillSize: SQLMetric)
extends AggregationIterator(
groupingExpressions,
originalInputAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 83f527f555..77be613b83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -103,7 +103,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
}
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -229,7 +229,7 @@ case class SampleExec(
override def output: Seq[Attribute] = child.output
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
@@ -322,7 +322,7 @@ case class RangeExec(
extends LeafExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// output attributes should not affect the results
override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index cb957b9666..577c34ba61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Accumulable, Accumulator, Accumulators}
+import org.apache.spark.{Accumulable, Accumulator, AccumulatorContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -204,7 +204,7 @@ private[sql] case class InMemoryRelation(
Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
private[sql] def uncache(blocking: Boolean): Unit = {
- Accumulators.remove(batchStats.id)
+ AccumulatorContext.remove(batchStats.id)
cachedColumnBuffers.unpersist(blocking)
_cachedColumnBuffers = null
}
@@ -217,7 +217,7 @@ private[sql] case class InMemoryTableScanExec(
extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = attributes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 573ca195ac..b6ecd3cb06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -38,10 +38,10 @@ case class BroadcastExchangeExec(
child: SparkPlan) extends Exchange {
override private[sql] lazy val metrics = Map(
- "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"),
- "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"),
- "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"),
- "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)"))
+ "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+ "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
+ "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+ "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b0a6b8f28a..587c603192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -46,7 +46,7 @@ case class BroadcastHashJoinExec(
extends BinaryExecNode with HashJoin with CodegenSupport {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 51afa0017d..a659bf26e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -35,7 +35,7 @@ case class BroadcastNestedLoopJoinExec(
condition: Option[Expression]) extends BinaryExecNode {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
/** BuildRight means the right relation <=> the broadcast relation. */
private val (streamed, broadcast) = buildSide match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 67f59197ad..8d7ecc442a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -86,7 +86,7 @@ case class CartesianProductExec(
override def output: Seq[Attribute] = left.output ++ right.output
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d6feedc272..9c173d7bf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
@@ -201,7 +201,7 @@ trait HashJoin {
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
case Inner =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index a242a078f6..3ef2fec352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -40,7 +40,7 @@ case class ShuffledHashJoinExec(
extends BinaryExecNode with HashJoin {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index a4c5491aff..775f8ac508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
/**
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
right: SparkPlan) extends BinaryExecNode with CodegenSupport {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = {
joinType match {
@@ -734,7 +734,7 @@ private class LeftOuterIterator(
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends OneSideOuterIterator(
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
@@ -750,7 +750,7 @@ private class RightOuterIterator(
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
@@ -778,7 +778,7 @@ private abstract class OneSideOuterIterator(
bufferedSideNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric) extends RowIterator {
+ numOutputRows: SQLMetric) extends RowIterator {
// A row to store the joined result, reused many times
protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -1016,7 +1016,7 @@ private class SortMergeFullOuterJoinScanner(
private class FullOuterIterator(
smjScanner: SortMergeFullOuterJoinScanner,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric) extends RowIterator {
+ numRows: SQLMetric) extends RowIterator {
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
override def advanceNext(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index 2708219ad3..adb81519db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,4 @@ import org.apache.spark.annotation.DeveloperApi
class SQLMetricInfo(
val name: String,
val accumulatorId: Long,
- val metricParam: String)
+ val metricType: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 5755c00c1f..7bf9225272 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.{NewAccumulator, SparkContext}
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.Utils
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
- name: String,
- val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) {
- // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
- override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
- new AccumulableInfo(id, Some(name), update, value, true, countFailedValues,
- Some(SQLMetrics.ACCUM_IDENTIFIER))
- }
-
- def reset(): Unit = {
- this.value = param.zero
- }
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
- /**
- * A function that defines how we aggregate the final accumulator results among all tasks,
- * and represent it in string for a SQL physical operator.
- */
- val stringValue: Seq[T] => String
-
- def zero: R
-}
+class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] {
+ // This is a workaround for SPARK-11013.
+ // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
+ // update it at the end of task and the value will be at least 0. Then we can filter out the -1
+ // values before calculate max, min, etc.
+ private[this] var _value = initValue
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
+ override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
- def value: T
-
- override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
- def add(incr: Long): LongSQLMetricValue = {
- _value += incr
- this
+ override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
+ case o: SQLMetric => _value += o.localValue
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- // Although there is a boxing here, it's fine because it's only called in SQLListener
- override def value: Long = _value
-
- // Needed for SQLListenerSuite
- override def equals(other: Any): Boolean = other match {
- case o: LongSQLMetricValue => value == o.value
- case _ => false
- }
+ override def isZero(): Boolean = _value == initValue
- override def hashCode(): Int = _value.hashCode()
-}
+ override def add(v: Long): Unit = _value += v
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
- extends SQLMetric[LongSQLMetricValue, Long](name, param) {
+ def +=(v: Long): Unit = _value += v
- override def +=(term: Long): Unit = {
- localValue.add(term)
- }
+ override def localValue: Long = _value
- override def add(term: Long): Unit = {
- localValue.add(term)
+ // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+ private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
}
-}
-
-private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
- extends SQLMetricParam[LongSQLMetricValue, Long] {
-
- override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
- override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
- r1.add(r2.value)
-
- override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
- override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
+ def reset(): Unit = _value = initValue
}
-private object LongSQLMetricParam
- extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L)
-
-private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
- (values: Seq[Long]) => {
- // This is a workaround for SPARK-11013.
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
- // it at the end of task and the value will be at least 0.
- val validValues = values.filter(_ >= 0)
- val Seq(sum, min, med, max) = {
- val metric = if (validValues.length == 0) {
- Seq.fill(4)(0L)
- } else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
- }
- metric.map(Utils.bytesToString)
- }
- s"\n$sum ($min, $med, $max)"
- }, -1L)
-
-private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
- (values: Seq[Long]) => {
- // This is a workaround for SPARK-11013.
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
- // it at the end of task and the value will be at least 0.
- val validValues = values.filter(_ >= 0)
- val Seq(sum, min, med, max) = {
- val metric = if (validValues.length == 0) {
- Seq.fill(4)(0L)
- } else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
- }
- metric.map(Utils.msDurationToString)
- }
- s"\n$sum ($min, $med, $max)"
- }, -1L)
private[sql] object SQLMetrics {
-
// Identifier for distinguishing SQL metrics from other accumulators
private[sql] val ACCUM_IDENTIFIER = "sql"
- private def createLongMetric(
- sc: SparkContext,
- name: String,
- param: LongSQLMetricParam): LongSQLMetric = {
- val acc = new LongSQLMetric(name, param)
- // This is an internal accumulator so we need to register it explicitly.
- Accumulators.register(acc)
- sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
- acc
- }
+ private[sql] val SUM_METRIC = "sum"
+ private[sql] val SIZE_METRIC = "size"
+ private[sql] val TIMING_METRIC = "timing"
- def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
- createLongMetric(sc, name, LongSQLMetricParam)
+ def createMetric(sc: SparkContext, name: String): SQLMetric = {
+ val acc = new SQLMetric(SUM_METRIC)
+ acc.register(sc, name = Some(name), countFailedValues = true)
+ acc
}
/**
* Create a metric to report the size information (including total, min, med, max) like data size,
* spill size, etc.
*/
- def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ def createSizeMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may looks like:
// data size total (min, med, max):
// 100GB (100MB, 1GB, 10GB)
- createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+ val acc = new SQLMetric(SIZE_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+ acc
}
- def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ def createTimingMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may looks like:
// duration(min, med, max):
// 5s (800ms, 1s, 2s)
- createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
- }
-
- def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
- val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
- val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
- val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
- val metricParam = metricParamName match {
- case `longSQLMetricParam` => LongSQLMetricParam
- case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
- case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
- }
- metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
+ val acc = new SQLMetric(TIMING_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+ acc
}
/**
- * A metric that its value will be ignored. Use this one when we need a metric parameter but don't
- * care about the value.
+ * A function that defines how we aggregate the final accumulator results among all tasks,
+ * and represent it in string for a SQL physical operator.
*/
- val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam)
+ def stringValue(metricsType: String, values: Seq[Long]): String = {
+ if (metricsType == SUM_METRIC) {
+ NumberFormat.getInstance().format(values.sum)
+ } else {
+ val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+ Utils.bytesToString
+ } else if (metricsType == TIMING_METRIC) {
+ Utils.msDurationToString
+ } else {
+ throw new IllegalStateException("unexpected metrics type: " + metricsType)
+ }
+
+ val validValues = values.filter(_ >= 0)
+ val Seq(sum, min, med, max) = {
+ val metric = if (validValues.length == 0) {
+ Seq.fill(4)(0L)
+ } else {
+ val sorted = validValues.sorted
+ Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ }
+ metric.map(strFormat)
+ }
+ s"\n$sum ($min, $med, $max)"
+ }
+ }
}
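To make the new rendering path concrete, a small hypothetical sketch of the helpers added above follows. It assumes the caller lives inside the `org.apache.spark.sql` package (these members are `private[sql]`); the object name is made up, and the printed strings are only approximate since they depend on `Utils.bytesToString` / `Utils.msDurationToString` formatting.

```scala
package org.apache.spark.sql.execution.metric

// Hypothetical sketch: exercising SQLMetrics.stringValue for the three
// metric types introduced above (SUM_METRIC, SIZE_METRIC, TIMING_METRIC).
object SQLMetricsStringValueSketch {
  def main(args: Array[String]): Unit = {
    // Sum metrics: just the formatted total of all per-task values.
    println(SQLMetrics.stringValue(SQLMetrics.SUM_METRIC, Seq(1L, 2L, 3L)))
    // => 6

    // Size metrics start at -1; tasks that never touched the metric are
    // filtered out before computing total (min, med, max).
    println(SQLMetrics.stringValue(SQLMetrics.SIZE_METRIC, Seq(-1L, 1024L, 4096L)))
    // => roughly "\n5.0 KB (1024.0 B, 4.0 KB, 4.0 KB)"

    // Timing metrics have the same shape but format durations.
    println(SQLMetrics.stringValue(SQLMetrics.TIMING_METRIC, Seq(-1L, 800L, 2000L)))
    // => roughly "\n2.8 s (800 ms, 2.0 s, 2.0 s)"
  }
}
```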
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 5ae9e916ad..9118593c0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -164,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
taskEnd.taskInfo.taskId,
taskEnd.stageId,
taskEnd.stageAttemptId,
- taskEnd.taskMetrics.accumulatorUpdates(),
+ taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)),
finishTask = true)
}
}
@@ -296,7 +296,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
}
}.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
- executionUIData.accumulatorMetrics(accumulatorId).metricParam)
+ executionUIData.accumulatorMetrics(accumulatorId).metricType)
case None =>
// This execution has been dropped
Map.empty
@@ -305,11 +305,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
private def mergeAccumulatorUpdates(
accumulatorUpdates: Seq[(Long, Any)],
- paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
+ metricTypeFunc: Long => String): Map[Long, String] = {
accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
- val param = paramFunc(accumulatorId)
- (accumulatorId,
- param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
+ val metricType = metricTypeFunc(accumulatorId)
+ accumulatorId ->
+ SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long]))
}
}
@@ -337,7 +337,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
// Filter out accumulators that are not SQL metrics
// For now we assume all SQL metrics are Long's that have been JSON serialized as String's
if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) {
- val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+ val newValue = a.update.map(_.toString.toLong).getOrElse(0L)
Some(a.copy(update = Some(newValue)))
} else {
None
@@ -403,7 +403,7 @@ private[ui] class SQLExecutionUIData(
private[ui] case class SQLPlanMetric(
name: String,
accumulatorId: Long,
- metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+ metricType: String)
/**
* Store all accumulatorUpdates for all tasks in a Spark stage.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 1959f1e368..8f5681bfc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,8 +80,7 @@ private[sql] object SparkPlanGraph {
planInfo.nodeName match {
case "WholeStageCodegen" =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val cluster = new SparkPlanGraphCluster(
@@ -106,8 +105,7 @@ private[sql] object SparkPlanGraph {
edges += SparkPlanGraphEdge(node.id, parent.id)
case name =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val node = new SparkPlanGraphNode(
nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4aea21e52a..0e6356b578 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,7 +22,7 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.Accumulators
+import org.apache.spark.AccumulatorContext
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -333,11 +333,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
- Accumulators.synchronized {
- val accsSize = Accumulators.originals.size
+ AccumulatorContext.synchronized {
+ val accsSize = AccumulatorContext.originals.size
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
- assert((accsSize - 2) == Accumulators.originals.size)
+ assert((accsSize - 2) == AccumulatorContext.originals.size)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1859c6e7ad..8de4d8bbd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -37,8 +37,8 @@ import org.apache.spark.util.{JsonProtocol, Utils}
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
- test("LongSQLMetric should not box Long") {
- val l = SQLMetrics.createLongMetric(sparkContext, "long")
+ test("SQLMetric should not box Long") {
+ val l = SQLMetrics.createMetric(sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@@ -300,12 +300,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
test("metrics can be loaded by history server") {
- val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+ val metric = SQLMetrics.createMetric(sparkContext, "zanzibar")
metric += 10L
val metricInfo = metric.toInfo(Some(metric.localValue), None)
metricInfo.update match {
- case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
- case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+ case Some(v: Long) => assert(v === 10L)
+ case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}")
case _ => fail("metric update is missing")
}
assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 09bd7f6e8f..8572ed16aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -21,18 +21,19 @@ import java.util.Properties
import org.mockito.Mockito.{mock, when}
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.ui.SparkUI
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
+ import org.apache.spark.AccumulatorSuite.makeInfo
private def createTestDataFrame: DataFrame = {
Seq(
@@ -72,9 +73,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
val metrics = mock(classOf[TaskMetrics])
- when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
- new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
- value = None, internal = true, countFailedValues = true)
+ when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+ val acc = new LongAccumulator
+ acc.metadata = AccumulatorMetadata(id, Some(""), true)
+ acc.setValue(update)
+ acc
}.toSeq)
metrics
}
@@ -130,16 +133,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 0,
+ createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -149,8 +153,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -189,8 +193,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
@@ -358,7 +362,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
// This task has both accumulators that are SQL metrics and accumulators that are not.
// The listener should only track the ones that are actually SQL metrics.
- val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+ val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb25ea0629..8a0578c1ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,7 +96,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
case other => other.longMetric("numOutputRows")
}
- metrics += metric.value.value
+ metrics += metric.value
}
}
sqlContext.listenerManager.register(listener)
@@ -126,9 +126,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
- metrics += qe.executedPlan.longMetric("dataSize").value.value
+ metrics += qe.executedPlan.longMetric("dataSize").value
val bottomAgg = qe.executedPlan.children(0).children(0)
- metrics += bottomAgg.longMetric("dataSize").value.value
+ metrics += bottomAgg.longMetric("dataSize").value
}
}
sqlContext.listenerManager.register(listener)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 007c3384e5..b52b96a804 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -55,7 +55,7 @@ case class HiveTableScanExec(
"Partition pruning predicates only supported for partitioned tables.")
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = outputSet ++
AttributeSet(partitionPruningPred.flatMap(_.references))