-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java     |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala      | 24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala      | 35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala      | 30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala    | 11
6 files changed, 97 insertions(+), 14 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index 1d1d7edb24..dbea8521be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -34,6 +34,7 @@ public abstract class BufferedRowIterator {
protected LinkedList<InternalRow> currentRows = new LinkedList<>();
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);
+ private long startTimeNs = System.nanoTime();
public boolean hasNext() throws IOException {
if (currentRows.isEmpty()) {
@@ -47,6 +48,14 @@ public abstract class BufferedRowIterator {
}
/**
+ * Returns the elapsed time since this object was created. This object represents a pipeline, so
+ * this is a measure of how long the pipeline has been running.
+ */
+ public long durationMs() {
+ return (System.nanoTime() - startTimeNs) / (1000 * 1000);
+ }
+
+ /**
* Initializes from array of iterators of InternalRow.
*/
public abstract void init(Iterator<InternalRow> iters[]);
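The durationMs() helper added above captures System.nanoTime() when the iterator is constructed and reports whole milliseconds on demand; since one BufferedRowIterator corresponds to one generated pipeline, that elapsed time is the pipeline's wall-clock duration. A minimal Scala sketch of the same pattern (PipelineTimer is an illustrative name, not part of this patch):

class PipelineTimer {
  // Captured once, when the pipeline (this object) is created.
  private val startTimeNs: Long = System.nanoTime()

  // Elapsed wall-clock time since construction, truncated to whole milliseconds.
  def durationMs(): Long = (System.nanoTime() - startTimeNs) / (1000 * 1000)
}

Because of the integer division, a pipeline that finishes in under a millisecond reports 0.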
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 67aef72ded..e3c7d7209a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
-import org.apache.spark.sql.execution.metric.LongSQLMetricValue
+import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
/**
@@ -264,6 +264,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
override def treeChildren: Seq[SparkPlan] = Nil
}
+object WholeStageCodegen {
+ val PIPELINE_DURATION_METRIC = "duration"
+}
+
/**
* WholeStageCodegen compiles a subtree of plans that support codegen together into a single Java
* function.
@@ -301,6 +305,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override private[sql] lazy val metrics = Map(
+ "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
+ WholeStageCodegen.PIPELINE_DURATION_METRIC))
+
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
@@ -339,6 +347,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
logDebug(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
+ val durationMs = longMetric("pipelineTime")
+
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
@@ -347,7 +357,11 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(Array(iter))
new Iterator[InternalRow] {
- override def hasNext: Boolean = buffer.hasNext
+ override def hasNext: Boolean = {
+ val v = buffer.hasNext
+ if (!v) durationMs += buffer.durationMs()
+ v
+ }
override def next: InternalRow = buffer.next()
}
}
@@ -358,7 +372,11 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(Array(leftIter, rightIter))
new Iterator[InternalRow] {
- override def hasNext: Boolean = buffer.hasNext
+ override def hasNext: Boolean = {
+ val v = buffer.hasNext
+ if (!v) durationMs += buffer.durationMs()
+ v
+ }
override def next: InternalRow = buffer.next()
}
}
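Both the single-upstream and two-upstream branches wrap the generated BufferedRowIterator the same way: hasNext delegates to the buffer, and once it reports exhaustion the elapsed pipeline time is added to the pipelineTime metric. A hedged, self-contained sketch of that wrapping pattern (timedIterator and onExhausted are illustrative names, not Spark APIs):

def timedIterator[T](underlying: Iterator[T])(onExhausted: Long => Unit): Iterator[T] =
  new Iterator[T] {
    private val startNs = System.nanoTime()
    private var reported = false

    override def hasNext: Boolean = {
      val more = underlying.hasNext
      if (!more && !reported) {
        // Report the elapsed milliseconds exactly once, at exhaustion.
        reported = true
        onExhausted((System.nanoTime() - startNs) / (1000 * 1000))
      }
      more
    }

    override def next(): T = underlying.next()
  }

The reported flag in the sketch keeps the duration from being added twice if hasNext is called again after returning false; the patch itself adds buffer.durationMs() on every false result and relies on callers not re-checking hasNext once it has returned false.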
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 6b43d273fe..7fa1390729 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -122,7 +122,7 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa
private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L)
-private object StaticsLongSQLMetricParam extends LongSQLMetricParam(
+private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
(values: Seq[Long]) => {
// This is a workaround for SPARK-11013.
// We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
@@ -140,6 +140,24 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam(
s"\n$sum ($min, $med, $max)"
}, -1L)
+private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
+ (values: Seq[Long]) => {
+ // This is a workaround for SPARK-11013.
+ // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
+ // it at the end of task and the value will be at least 0.
+ val validValues = values.filter(_ >= 0)
+ val Seq(sum, min, med, max) = {
+ val metric = if (validValues.length == 0) {
+ Seq.fill(4)(0L)
+ } else {
+ val sorted = validValues.sorted
+ Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ }
+ metric.map(Utils.msDurationToString)
+ }
+ s"\n$sum ($min, $med, $max)"
+ }, -1L)
+
private[sql] object SQLMetrics {
// Identifier for distinguishing SQL metrics from other accumulators
@@ -168,15 +186,24 @@ private[sql] object SQLMetrics {
// The final result of this metric in physical operator UI may look like:
// data size total (min, med, max):
// 100GB (100MB, 1GB, 10GB)
- createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam)
+ createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+ }
+
+ def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ // The final result of this metric in physical operator UI may look like:
+ // duration total (min, med, max):
+ // 5s (800ms, 1s, 2s)
+ createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
}
def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
- val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam)
+ val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
+ val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
val metricParam = metricParamName match {
case `longSQLMetricParam` => LongSQLMetricParam
- case `staticsSQLMetricParam` => StaticsLongSQLMetricParam
+ case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
+ case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
}
metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
}
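StatisticsTimingSQLMetricParam aggregates the per-task durations into a "sum (min, med, max)" string, skipping the -1 sentinel used for accumulators that were never updated, and formats each number as a human-readable duration. A standalone sketch of that aggregation, with formatMs standing in for Spark's Utils.msDurationToString:

// formatMs is a simplified placeholder for Utils.msDurationToString.
def formatMs(ms: Long): String =
  if (ms < 1000) s"$ms ms" else s"${ms / 1000.0} s"

def timingSummary(values: Seq[Long]): String = {
  // -1 marks an accumulator that was never updated; drop those entries.
  val valid = values.filter(_ >= 0)
  val stats =
    if (valid.isEmpty) Seq.fill(4)(0L)
    else {
      val sorted = valid.sorted
      Seq(sorted.sum, sorted.head, sorted(valid.length / 2), sorted.last)
    }
  val Seq(sum, min, med, max) = stats.map(formatMs)
  s"\n$sum ($min, $med, $max)"
}

For example, timingSummary(Seq(800L, 1000L, 2000L, -1L)) produces "3.8 s (800 ms, 1.0 s, 2.0 s)" (after a leading newline) with this placeholder formatter, matching the shape shown in the createTimingMetric comment.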
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 8a36d32240..24a01f5be1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.{SparkPlanInfo, WholeStageCodegen}
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -79,12 +79,19 @@ private[sql] object SparkPlanGraph {
exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
planInfo.nodeName match {
case "WholeStageCodegen" =>
+ val metrics = planInfo.metrics.map { metric =>
+ SQLPlanMetric(metric.name, metric.accumulatorId,
+ SQLMetrics.getMetricParam(metric.metricParam))
+ }
+
val cluster = new SparkPlanGraphCluster(
nodeIdGenerator.getAndIncrement(),
planInfo.nodeName,
planInfo.simpleString,
- mutable.ArrayBuffer[SparkPlanGraphNode]())
+ mutable.ArrayBuffer[SparkPlanGraphNode](),
+ metrics)
nodes += cluster
+
buildSparkPlanGraphNode(
planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges)
case "InputAdapter" =>
@@ -166,13 +173,26 @@ private[ui] class SparkPlanGraphCluster(
id: Long,
name: String,
desc: String,
- val nodes: mutable.ArrayBuffer[SparkPlanGraphNode])
- extends SparkPlanGraphNode(id, name, desc, Map.empty, Nil) {
+ val nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
+ metrics: Seq[SQLPlanMetric])
+ extends SparkPlanGraphNode(id, name, desc, Map.empty, metrics) {
override def makeDotNode(metricsValue: Map[Long, String]): String = {
+ val duration = metrics.filter(_.name.startsWith(WholeStageCodegen.PIPELINE_DURATION_METRIC))
+ val labelStr = if (duration.nonEmpty) {
+ require(duration.length == 1)
+ val id = duration(0).accumulatorId
+ if (metricsValue.contains(id)) {
+ name + "\n\n" + metricsValue(id)
+ } else {
+ name
+ }
+ } else {
+ name
+ }
s"""
| subgraph cluster${id} {
- | label="${StringEscapeUtils.escapeJava(name)}";
+ | label="${StringEscapeUtils.escapeJava(labelStr)}";
| ${nodes.map(_.makeDotNode(metricsValue)).mkString(" \n")}
| }
""".stripMargin
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index fa68c1a91d..695b1824e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -309,7 +309,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we still can check the value.
- assert(metricValues.values.toSeq === Seq("2"))
+ assert(metricValues.values.toSeq.exists(_ === "2"))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 4641a1ad78..09bd7f6e8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -81,7 +81,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
test("basic") {
def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = {
- assert(actual === expected.mapValues(_.toString))
+ assert(actual.size == expected.size)
+ expected.foreach { e =>
+ // The values in actual can be SQL metrics, meaning that they contain additional formatting
+ // when converted to string. Verify that they start with the expected value.
+ // TODO: this is brittle. There is no requirement that the actual string needs to start
+ // with the accumulator value.
+ assert(actual.contains(e._1))
+ val v = actual(e._1).trim
+ assert(v.startsWith(e._2.toString))
+ }
}
val listener = new SQLListener(sqlContext.sparkContext.conf)
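Because timing metrics now render with extra formatting, the listener test compares each reported string by prefix rather than by strict equality. A hedged sketch of that lenient comparison as a standalone predicate (metricsMatch is an illustrative name):

def metricsMatch(actual: Map[Long, String], expected: Map[Long, Long]): Boolean =
  actual.size == expected.size && expected.forall { case (accId, value) =>
    // The rendered metric may carry formatting after the number, so match by prefix.
    actual.get(accId).exists(_.trim.startsWith(value.toString))
  }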