diff options
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 53 |
1 files changed, 44 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 5e573b3159..17eae88b49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.internal.SQLConf @@ -35,12 +35,38 @@ import org.apache.spark.sql.internal.SQLConf * Usage: * {{{ * import org.apache.spark.sql.execution.debug._ - * sql("SELECT key FROM src").debug() - * dataFrame.typeCheck() + * sql("SELECT 1").debug() + * sql("SELECT 1").debugCodegen() * }}} */ package object debug { + /** Helper function to evade the println() linter. */ + private def debugPrint(msg: String): Unit = { + // scalastyle:off println + println(msg) + // scalastyle:on println + } + + def codegenString(plan: SparkPlan): String = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() + plan transform { + case s: WholeStageCodegen => + codegenSubtrees += s + s + case s => s + } + var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" + for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" + output += s + output += "\nGenerated code:\n" + val (_, source) = s.doCodeGen() + output += s"${CodeFormatter.format(source)}\n" + } + output + } + /** * Augments [[SQLContext]] with debug methods. */ @@ -51,9 +77,9 @@ package object debug { } /** - * Augments [[DataFrame]]s with debug methods. + * Augments [[Dataset]]s with debug methods. */ - implicit class DebugQuery(query: DataFrame) extends Logging { + implicit class DebugQuery(query: Dataset[_]) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -62,12 +88,20 @@ package object debug { visited += new TreeNodeRef(s) DebugNode(s) } - logDebug(s"Results returned: ${debugPlan.execute().count()}") + debugPrint(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { case d: DebugNode => d.dumpStats() case _ => } } + + /** + * Prints to stdout all the generated code found in this plan (i.e. the output of each + * WholeStageCodegen subtree). + */ + def debugCodegen(): Unit = { + debugPrint(codegenString(query.queryExecution.executedPlan)) + } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { @@ -87,6 +121,7 @@ package object debug { /** * A collection of metrics for each column of output. + * * @param elementTypes the actual runtime types for the output. Useful when there are bugs * causing the wrong data to be projected. */ @@ -99,11 +134,11 @@ package object debug { val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - logDebug(s"== ${child.simpleString} ==") - logDebug(s"Tuples output: ${tupleCount.value}") + debugPrint(s"== ${child.simpleString} ==") + debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") - logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") + debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } |