aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-03-29 13:31:51 -0700
committerReynold Xin <rxin@databricks.com>2016-03-29 13:31:51 -0700
commite58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (patch)
tree4b25ec8b773e2864c2b09b5a4a47a775e236ddfd /sql/core
parent838cb4583dd68a9d7ef3bd74b9b4d33f32f177fc (diff)
downloadspark-e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a.tar.gz
spark-e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a.tar.bz2
spark-e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a.zip
[SPARK-14227][SQL] Add method for printing out generated code for debugging
## What changes were proposed in this pull request? This adds `debugCodegen` to the debug package for query execution. ## How was this patch tested? Unit and manual testing. Output example: ``` scala> import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.debug._ scala> sqlContext.range(100).groupBy("id").count().orderBy("id").debugCodegen() Found 3 WholeStageCodegen subtrees. == Subtree 1 / 3 == WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) : +- Range 0, 1, 1, 100, [id#0L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ /** Codegened pipeline for: /* 006 */ * TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) /* 007 */ +- Range 0, 1, 1, 100, [id#0L] /* 008 */ */ /* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 010 */ private Object[] references; /* 011 */ private boolean agg_initAgg; /* 012 */ private org.apache.spark.sql.execution.aggregate.TungstenAggregate agg_plan; /* 013 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 014 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 016 */ private org.apache.spark.sql.execution.metric.LongSQLMetric range_numOutputRows; /* 017 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue range_metricValue; /* 018 */ private boolean range_initRange; /* 019 */ private long range_partitionEnd; /* 020 */ private long range_number; /* 021 */ private boolean range_overflow; /* 022 */ private scala.collection.Iterator range_input; /* 023 */ private UnsafeRow range_result; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 025 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 026 */ private UnsafeRow agg_result; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 028 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 029 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner agg_unsafeRowJoiner; /* 030 */ private org.apache.spark.sql.execution.metric.LongSQLMetric wholestagecodegen_numOutputRows; /* 031 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue wholestagecodegen_metricValue; /* 032 */ /* 033 */ public GeneratedIterator(Object[] references) { /* 034 */ this.references = references; /* 035 */ } /* 036 */ /* 037 */ public void init(scala.collection.Iterator inputs[]) { /* 038 */ agg_initAgg = false; /* 039 */ this.agg_plan = (org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0]; /* 040 */ agg_hashMap = agg_plan.createHashMap(); /* 041 */ /* 042 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1]; /* 043 */ range_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) range_numOutputRows.localValue(); /* 044 */ range_initRange = false; /* 045 */ range_partitionEnd = 0L; /* 046 */ range_number = 0L; /* 047 */ range_overflow = false; /* 048 */ range_input = inputs[0]; /* 049 */ range_result = new UnsafeRow(1); /* 050 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 051 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 052 */ agg_result = new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ agg_unsafeRowJoiner = agg_plan.createUnsafeJoiner(); /* 056 */ this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[2]; /* 057 */ wholestagecodegen_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) wholestagecodegen_numOutputRows.localValue(); /* 058 */ } /* 059 */ /* 060 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 061 */ /*** PRODUCE: Range 0, 1, 1, 100, [id#0L] */ /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ if (range_input.hasNext()) { /* 067 */ initRange(((InternalRow) range_input.next()).getInt(0)); /* 068 */ } else { /* 069 */ return; /* 070 */ } /* 071 */ } /* 072 */ /* 073 */ while (!range_overflow && range_number < range_partitionEnd) { /* 074 */ long range_value = range_number; /* 075 */ range_number += 1L; /* 076 */ if (range_number < range_value ^ 1L < 0) { /* 077 */ range_overflow = true; /* 078 */ } /* 079 */ /* 080 */ /*** CONSUME: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) */ /* 081 */ /* 082 */ // generate grouping key /* 083 */ agg_rowWriter.write(0, range_value); /* 084 */ /* hash(input[0, bigint], 42) */ /* 085 */ int agg_value1 = 42; /* 086 */ /* 087 */ agg_value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(range_value, agg_value1); /* 088 */ UnsafeRow agg_aggBuffer = null; /* 089 */ if (true) { /* 090 */ // try to get the buffer from hash map /* 091 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1); /* 092 */ } /* 093 */ if (agg_aggBuffer == null) { /* 094 */ if (agg_sorter == null) { /* 095 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 096 */ } else { /* 097 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 098 */ } /* 099 */ /* 100 */ // the hash map had be spilled, it should have enough memory now, /* 101 */ // try to allocate buffer again. /* 102 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1); /* 103 */ if (agg_aggBuffer == null) { /* 104 */ // failed to allocate the first page /* 105 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 106 */ } /* 107 */ } /* 108 */ /* 109 */ // evaluate aggregate function /* 110 */ /* (input[0, bigint] + 1) */ /* 111 */ /* input[0, bigint] */ /* 112 */ long agg_value4 = agg_aggBuffer.getLong(0); /* 113 */ /* 114 */ long agg_value3 = -1L; /* 115 */ agg_value3 = agg_value4 + 1L; /* 116 */ // update aggregate buffer /* 117 */ agg_aggBuffer.setLong(0, agg_value3); /* 118 */ /* 119 */ if (shouldStop()) return; /* 120 */ } /* 121 */ /* 122 */ agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter); /* 123 */ } /* 124 */ /* 125 */ private void initRange(int idx) { /* 126 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 127 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(1L); /* 128 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(100L); /* 129 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 130 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 131 */ /* 132 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 133 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 134 */ range_number = Long.MAX_VALUE; /* 135 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 136 */ range_number = Long.MIN_VALUE; /* 137 */ } else { /* 138 */ range_number = st.longValue(); /* 139 */ } /* 140 */ /* 141 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 142 */ .multiply(step).add(start); /* 143 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 144 */ range_partitionEnd = Long.MAX_VALUE; /* 145 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 146 */ range_partitionEnd = Long.MIN_VALUE; /* 147 */ } else { /* 148 */ range_partitionEnd = end.longValue(); /* 149 */ } /* 150 */ /* 151 */ range_metricValue.add((range_partitionEnd - range_number) / 1L); /* 152 */ } /* 153 */ /* 154 */ protected void processNext() throws java.io.IOException { /* 155 */ /*** PRODUCE: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) */ /* 156 */ /* 157 */ if (!agg_initAgg) { /* 158 */ agg_initAgg = true; /* 159 */ agg_doAggregateWithKeys(); /* 160 */ } /* 161 */ /* 162 */ // output the result /* 163 */ while (agg_mapIter.next()) { /* 164 */ wholestagecodegen_metricValue.add(1); /* 165 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 166 */ UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue(); /* 167 */ /* 168 */ UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer1); /* 169 */ /* 170 */ /*** CONSUME: WholeStageCodegen */ /* 171 */ /* 172 */ append(agg_resultRow); /* 173 */ /* 174 */ if (shouldStop()) return; /* 175 */ } /* 176 */ /* 177 */ agg_mapIter.close(); /* 178 */ if (agg_sorter == null) { /* 179 */ agg_hashMap.free(); /* 180 */ } /* 181 */ } /* 182 */ } == Subtree 2 / 3 == WholeStageCodegen : +- Sort [id#0L ASC], true, 0 : +- INPUT +- Exchange rangepartitioning(id#0L ASC, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) : +- INPUT +- Exchange hashpartitioning(id#0L, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) : +- Range 0, 1, 1, 100, [id#0L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ /** Codegened pipeline for: /* 006 */ * Sort [id#0L ASC], true, 0 /* 007 */ +- INPUT /* 008 */ */ /* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 010 */ private Object[] references; /* 011 */ private boolean sort_needToSort; /* 012 */ private org.apache.spark.sql.execution.Sort sort_plan; /* 013 */ private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter; /* 014 */ private org.apache.spark.executor.TaskMetrics sort_metrics; /* 015 */ private scala.collection.Iterator<UnsafeRow> sort_sortedIter; /* 016 */ private scala.collection.Iterator inputadapter_input; /* 017 */ private org.apache.spark.sql.execution.metric.LongSQLMetric sort_dataSize; /* 018 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue sort_metricValue; /* 019 */ private org.apache.spark.sql.execution.metric.LongSQLMetric sort_spillSize; /* 020 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue sort_metricValue1; /* 021 */ /* 022 */ public GeneratedIterator(Object[] references) { /* 023 */ this.references = references; /* 024 */ } /* 025 */ /* 026 */ public void init(scala.collection.Iterator inputs[]) { /* 027 */ sort_needToSort = true; /* 028 */ this.sort_plan = (org.apache.spark.sql.execution.Sort) references[0]; /* 029 */ sort_sorter = sort_plan.createSorter(); /* 030 */ sort_metrics = org.apache.spark.TaskContext.get().taskMetrics(); /* 031 */ /* 032 */ inputadapter_input = inputs[0]; /* 033 */ this.sort_dataSize = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1]; /* 034 */ sort_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) sort_dataSize.localValue(); /* 035 */ this.sort_spillSize = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[2]; /* 036 */ sort_metricValue1 = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) sort_spillSize.localValue(); /* 037 */ } /* 038 */ /* 039 */ private void sort_addToSorter() throws java.io.IOException { /* 040 */ /*** PRODUCE: INPUT */ /* 041 */ /* 042 */ while (inputadapter_input.hasNext()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ /*** CONSUME: Sort [id#0L ASC], true, 0 */ /* 045 */ /* 046 */ sort_sorter.insertRow((UnsafeRow)inputadapter_row); /* 047 */ if (shouldStop()) return; /* 048 */ } /* 049 */ /* 050 */ } /* 051 */ /* 052 */ protected void processNext() throws java.io.IOException { /* 053 */ /*** PRODUCE: Sort [id#0L ASC], true, 0 */ /* 054 */ if (sort_needToSort) { /* 055 */ sort_addToSorter(); /* 056 */ Long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled(); /* 057 */ sort_sortedIter = sort_sorter.sort(); /* 058 */ sort_metricValue.add(sort_sorter.getPeakMemoryUsage()); /* 059 */ sort_metricValue1.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore); /* 060 */ sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage()); /* 061 */ sort_needToSort = false; /* 062 */ } /* 063 */ /* 064 */ while (sort_sortedIter.hasNext()) { /* 065 */ UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next(); /* 066 */ /* 067 */ /*** CONSUME: WholeStageCodegen */ /* 068 */ /* 069 */ append(sort_outputRow); /* 070 */ /* 071 */ if (shouldStop()) return; /* 072 */ } /* 073 */ } /* 074 */ } == Subtree 3 / 3 == WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) : +- INPUT +- Exchange hashpartitioning(id#0L, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) : +- Range 0, 1, 1, 100, [id#0L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ /** Codegened pipeline for: /* 006 */ * TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) /* 007 */ +- INPUT /* 008 */ */ /* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 010 */ private Object[] references; /* 011 */ private boolean agg_initAgg; /* 012 */ private org.apache.spark.sql.execution.aggregate.TungstenAggregate agg_plan; /* 013 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 014 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 016 */ private scala.collection.Iterator inputadapter_input; /* 017 */ private UnsafeRow agg_result; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 020 */ private UnsafeRow agg_result1; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder1; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter1; /* 023 */ private org.apache.spark.sql.execution.metric.LongSQLMetric wholestagecodegen_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue wholestagecodegen_metricValue; /* 025 */ /* 026 */ public GeneratedIterator(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(scala.collection.Iterator inputs[]) { /* 031 */ agg_initAgg = false; /* 032 */ this.agg_plan = (org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0]; /* 033 */ agg_hashMap = agg_plan.createHashMap(); /* 034 */ /* 035 */ inputadapter_input = inputs[0]; /* 036 */ agg_result = new UnsafeRow(1); /* 037 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 038 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 039 */ agg_result1 = new UnsafeRow(2); /* 040 */ this.agg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result1, 0); /* 041 */ this.agg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder1, 2); /* 042 */ this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1]; /* 043 */ wholestagecodegen_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) wholestagecodegen_numOutputRows.localValue(); /* 044 */ } /* 045 */ /* 046 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 047 */ /*** PRODUCE: INPUT */ /* 048 */ /* 049 */ while (inputadapter_input.hasNext()) { /* 050 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 051 */ /*** CONSUME: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) */ /* 052 */ /* input[0, bigint] */ /* 053 */ long inputadapter_value = inputadapter_row.getLong(0); /* 054 */ /* input[1, bigint] */ /* 055 */ long inputadapter_value1 = inputadapter_row.getLong(1); /* 056 */ /* 057 */ // generate grouping key /* 058 */ agg_rowWriter.write(0, inputadapter_value); /* 059 */ /* hash(input[0, bigint], 42) */ /* 060 */ int agg_value1 = 42; /* 061 */ /* 062 */ agg_value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(inputadapter_value, agg_value1); /* 063 */ UnsafeRow agg_aggBuffer = null; /* 064 */ if (true) { /* 065 */ // try to get the buffer from hash map /* 066 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1); /* 067 */ } /* 068 */ if (agg_aggBuffer == null) { /* 069 */ if (agg_sorter == null) { /* 070 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 071 */ } else { /* 072 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 073 */ } /* 074 */ /* 075 */ // the hash map had be spilled, it should have enough memory now, /* 076 */ // try to allocate buffer again. /* 077 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1); /* 078 */ if (agg_aggBuffer == null) { /* 079 */ // failed to allocate the first page /* 080 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 081 */ } /* 082 */ } /* 083 */ /* 084 */ // evaluate aggregate function /* 085 */ /* (input[0, bigint] + input[2, bigint]) */ /* 086 */ /* input[0, bigint] */ /* 087 */ long agg_value4 = agg_aggBuffer.getLong(0); /* 088 */ /* 089 */ long agg_value3 = -1L; /* 090 */ agg_value3 = agg_value4 + inputadapter_value1; /* 091 */ // update aggregate buffer /* 092 */ agg_aggBuffer.setLong(0, agg_value3); /* 093 */ if (shouldStop()) return; /* 094 */ } /* 095 */ /* 096 */ agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter); /* 097 */ } /* 098 */ /* 099 */ protected void processNext() throws java.io.IOException { /* 100 */ /*** PRODUCE: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) */ /* 101 */ /* 102 */ if (!agg_initAgg) { /* 103 */ agg_initAgg = true; /* 104 */ agg_doAggregateWithKeys(); /* 105 */ } /* 106 */ /* 107 */ // output the result /* 108 */ while (agg_mapIter.next()) { /* 109 */ wholestagecodegen_metricValue.add(1); /* 110 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 111 */ UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue(); /* 112 */ /* 113 */ /* input[0, bigint] */ /* 114 */ long agg_value6 = agg_aggKey.getLong(0); /* 115 */ /* input[0, bigint] */ /* 116 */ long agg_value7 = agg_aggBuffer1.getLong(0); /* 117 */ /* 118 */ /*** CONSUME: WholeStageCodegen */ /* 119 */ /* 120 */ agg_rowWriter1.write(0, agg_value6); /* 121 */ /* 122 */ agg_rowWriter1.write(1, agg_value7); /* 123 */ append(agg_result1); /* 124 */ /* 125 */ if (shouldStop()) return; /* 126 */ } /* 127 */ /* 128 */ agg_mapIter.close(); /* 129 */ if (agg_sorter == null) { /* 130 */ agg_hashMap.free(); /* 131 */ } /* 132 */ } /* 133 */ } ``` rxin Author: Eric Liang <ekl@databricks.com> Closes #12025 from ericl/spark-14227.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala46
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala7
3 files changed, 60 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1b13c8fd22..da3ee46b7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -297,7 +297,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegen.PIPELINE_DURATION_METRIC))
- override def doExecute(): RDD[InternalRow] = {
+ /**
+ * Generates code for this subtree.
+ *
+ * @return the tuple of the codegen context and the actual generated source.
+ */
+ def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val references = ctx.references.toArray
@@ -334,6 +339,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
logDebug(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
+ (ctx, cleanedSource)
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ val (ctx, cleanedSource) = doCodeGen()
+ val references = ctx.references.toArray
val durationMs = longMetric("pipelineTime")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5e573b3159..9916482a68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf
@@ -41,6 +41,13 @@ import org.apache.spark.sql.internal.SQLConf
*/
package object debug {
+ /** Helper function to evade the println() linter. */
+ private def debugPrint(msg: String): Unit = {
+ // scalastyle:off println
+ println(msg)
+ // scalastyle:on println
+ }
+
/**
* Augments [[SQLContext]] with debug methods.
*/
@@ -62,12 +69,41 @@ package object debug {
visited += new TreeNodeRef(s)
DebugNode(s)
}
- logDebug(s"Results returned: ${debugPlan.execute().count()}")
+ debugPrint(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
case d: DebugNode => d.dumpStats()
case _ =>
}
}
+
+ /**
+ * Prints to stdout all the generated code found in this plan (i.e. the output of each
+ * WholeStageCodegen subtree).
+ */
+ def debugCodegen(): Unit = {
+ debugPrint(debugCodegenString())
+ }
+
+ /** Visible for testing. */
+ def debugCodegenString(): String = {
+ val plan = query.queryExecution.executedPlan
+ val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]()
+ plan transform {
+ case s: WholeStageCodegen =>
+ codegenSubtrees += s
+ s
+ case s => s
+ }
+ var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n"
+ for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) {
+ output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n"
+ output += s
+ output += "\nGenerated code:\n"
+ val (_, source) = s.doCodeGen()
+ output += s"${CodeFormatter.format(source)}\n"
+ }
+ output
+ }
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
@@ -99,11 +135,11 @@ package object debug {
val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())
def dumpStats(): Unit = {
- logDebug(s"== ${child.simpleString} ==")
- logDebug(s"Tuples output: ${tupleCount.value}")
+ debugPrint(s"== ${child.simpleString} ==")
+ debugPrint(s"Tuples output: ${tupleCount.value}")
child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
- logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
+ debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 22189477d2..979265e274 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -25,4 +25,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
test("DataFrame.debug()") {
testData.debug()
}
+
+ test("debugCodegen") {
+ val res = sqlContext.range(10).groupBy("id").count().debugCodegenString()
+ assert(res.contains("Subtree 1 / 2"))
+ assert(res.contains("Subtree 2 / 2"))
+ assert(res.contains("Object[]"))
+ }
}