|  |  |  |
|---|---|---|
| author | Kazuaki Ishizaki <ishizaki@jp.ibm.com> | 2017-03-10 18:04:37 +0100 |
| committer | Herman van Hovell <hvanhovell@databricks.com> | 2017-03-10 18:04:37 +0100 |
| commit | fcb68e0f5d49234ac4527109887ff08cd4e1c29f (patch) | |
| tree | a3277000d0386030805c2bf3cdd5e1a17ee28e7c /core/src | |
| parent | 501b7111997bc74754663348967104181b43319b (diff) | |
[SPARK-19786][SQL] Facilitate loop optimizations in a JIT compiler regarding range()
## What changes were proposed in this pull request?
This PR improves performance of operations with `range()` by changing Java code generated by Catalyst. This PR is inspired by the [blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html).
This PR changes the generated code in two ways:
1. Replace a while-loop over long instance variables with a for-loop over an int local variable.
2. Suppress the `shouldStop()` check when it is unnecessary (e.g. when `append()` is not generated).
These changes feed simpler Java code to the JIT compiler, which facilitates its loop optimizations. Performance is improved by 7.6x.
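The loop transformation can be illustrated with a standalone sketch (class and variable names here are illustrative, not the actual Catalyst-generated identifiers): both methods count the same range, but the second drives the loop with an int local variable and omits the per-iteration `shouldStop()`-style check, a shape that a JIT compiler can more readily unroll and optimize.

```java
// Illustrative sketch of the loop shape before and after this PR.
// Names are hypothetical; Catalyst generates different identifiers.
public class LoopShapes {
    // Before: while-loop driven by long variables, with a per-row exit check.
    static long countWhile(long start, long end) {
        long number = start;
        long count = 0L;
        while (number != end) {
            number += 1L;
            count += 1L;
            // in the real generated code: if (shouldStop()) return;
        }
        return count;
    }

    // After: for-loop over an int local; the exit check is eliminated,
    // so the JIT sees a simple counted loop it can unroll.
    static long countFor(long start, long end) {
        long count = 0L;
        long range = end - start;
        int localEnd = (int) range; // a batch is at most 1000 elements, so it fits in an int
        for (int localIdx = 0; localIdx < localEnd; localIdx++) {
            count += 1L;
        }
        return count;
    }

    public static void main(String[] args) {
        long a = countWhile(0L, 1000L);
        long b = countFor(0L, 1000L);
        if (a != b) throw new AssertionError(a + " != " + b);
        System.out.println(a);
    }
}
```

Both variants compute the same result; the difference is purely in how friendly the loop shape is to the JIT compiler.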
Benchmark program:
```scala
val N = 1 << 29
val iters = 2
val benchmark = new Benchmark("range.count", N * iters)
benchmark.addCase(s"with this PR") { i =>
  var n = 0
  var len = 0
  while (n < iters) {
    len += sparkSession.range(N).selectExpr("count(id)").collect.length
    n += 1
  }
}
benchmark.run
```
Performance result without this PR
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz
range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
w/o this PR 1349 / 1356 796.2 1.3 1.0X
```
Performance result with this PR
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz
range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
with this PR 177 / 271 6065.3 0.2 1.0X
```
Here is a comparison of the generated code without and with this PR. Only the method `agg_doAggregateWithoutKey` changes.
Generated code without this PR
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private boolean agg_initAgg;
/* 009 */ private boolean agg_bufIsNull;
/* 010 */ private long agg_bufValue;
/* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */ private boolean range_initRange;
/* 014 */ private long range_number;
/* 015 */ private TaskContext range_taskContext;
/* 016 */ private InputMetrics range_inputMetrics;
/* 017 */ private long range_batchEnd;
/* 018 */ private long range_numElementsTodo;
/* 019 */ private scala.collection.Iterator range_input;
/* 020 */ private UnsafeRow range_result;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
/* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
/* 025 */ private UnsafeRow agg_result;
/* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 028 */
/* 029 */ public GeneratedIterator(Object[] references) {
/* 030 */ this.references = references;
/* 031 */ }
/* 032 */
/* 033 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */ partitionIndex = index;
/* 035 */ this.inputs = inputs;
/* 036 */ agg_initAgg = false;
/* 037 */
/* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */ range_initRange = false;
/* 041 */ range_number = 0L;
/* 042 */ range_taskContext = TaskContext.get();
/* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */ range_batchEnd = 0;
/* 045 */ range_numElementsTodo = 0L;
/* 046 */ range_input = inputs[0];
/* 047 */ range_result = new UnsafeRow(1);
/* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */ agg_result = new UnsafeRow(1);
/* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */
/* 056 */ }
/* 057 */
/* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */ // initialize aggregation buffer
/* 060 */ agg_bufIsNull = false;
/* 061 */ agg_bufValue = 0L;
/* 062 */
/* 063 */ // initialize Range
/* 064 */ if (!range_initRange) {
/* 065 */ range_initRange = true;
/* 066 */ initRange(partitionIndex);
/* 067 */ }
/* 068 */
/* 069 */ while (true) {
/* 070 */ while (range_number != range_batchEnd) {
/* 071 */ long range_value = range_number;
/* 072 */ range_number += 1L;
/* 073 */
/* 074 */ // do aggregate
/* 075 */ // common sub-expressions
/* 076 */
/* 077 */ // evaluate aggregate function
/* 078 */ boolean agg_isNull1 = false;
/* 079 */
/* 080 */ long agg_value1 = -1L;
/* 081 */ agg_value1 = agg_bufValue + 1L;
/* 082 */ // update aggregation buffer
/* 083 */ agg_bufIsNull = false;
/* 084 */ agg_bufValue = agg_value1;
/* 085 */
/* 086 */ if (shouldStop()) return;
/* 087 */ }
/* 088 */
/* 089 */ if (range_taskContext.isInterrupted()) {
/* 090 */ throw new TaskKilledException();
/* 091 */ }
/* 092 */
/* 093 */ long range_nextBatchTodo;
/* 094 */ if (range_numElementsTodo > 1000L) {
/* 095 */ range_nextBatchTodo = 1000L;
/* 096 */ range_numElementsTodo -= 1000L;
/* 097 */ } else {
/* 098 */ range_nextBatchTodo = range_numElementsTodo;
/* 099 */ range_numElementsTodo = 0;
/* 100 */ if (range_nextBatchTodo == 0) break;
/* 101 */ }
/* 102 */ range_numOutputRows.add(range_nextBatchTodo);
/* 103 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 104 */
/* 105 */ range_batchEnd += range_nextBatchTodo * 1L;
/* 106 */ }
/* 107 */
/* 108 */ }
/* 109 */
/* 110 */ private void initRange(int idx) {
/* 111 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 112 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 113 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 114 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 115 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 116 */ long partitionEnd;
/* 117 */
/* 118 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 119 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 120 */ range_number = Long.MAX_VALUE;
/* 121 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 122 */ range_number = Long.MIN_VALUE;
/* 123 */ } else {
/* 124 */ range_number = st.longValue();
/* 125 */ }
/* 126 */ range_batchEnd = range_number;
/* 127 */
/* 128 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 129 */ .multiply(step).add(start);
/* 130 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 131 */ partitionEnd = Long.MAX_VALUE;
/* 132 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 133 */ partitionEnd = Long.MIN_VALUE;
/* 134 */ } else {
/* 135 */ partitionEnd = end.longValue();
/* 136 */ }
/* 137 */
/* 138 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 139 */ java.math.BigInteger.valueOf(range_number));
/* 140 */ range_numElementsTodo = startToEnd.divide(step).longValue();
/* 141 */ if (range_numElementsTodo < 0) {
/* 142 */ range_numElementsTodo = 0;
/* 143 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 144 */ range_numElementsTodo++;
/* 145 */ }
/* 146 */ }
/* 147 */
/* 148 */ protected void processNext() throws java.io.IOException {
/* 149 */ while (!agg_initAgg) {
/* 150 */ agg_initAgg = true;
/* 151 */ long agg_beforeAgg = System.nanoTime();
/* 152 */ agg_doAggregateWithoutKey();
/* 153 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
/* 154 */
/* 155 */ // output the result
/* 156 */
/* 157 */ agg_numOutputRows.add(1);
/* 158 */ agg_rowWriter.zeroOutNullBytes();
/* 159 */
/* 160 */ if (agg_bufIsNull) {
/* 161 */ agg_rowWriter.setNullAt(0);
/* 162 */ } else {
/* 163 */ agg_rowWriter.write(0, agg_bufValue);
/* 164 */ }
/* 165 */ append(agg_result);
/* 166 */ }
/* 167 */ }
/* 168 */ }
```
Generated code with this PR
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private boolean agg_initAgg;
/* 009 */ private boolean agg_bufIsNull;
/* 010 */ private long agg_bufValue;
/* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */ private boolean range_initRange;
/* 014 */ private long range_number;
/* 015 */ private TaskContext range_taskContext;
/* 016 */ private InputMetrics range_inputMetrics;
/* 017 */ private long range_batchEnd;
/* 018 */ private long range_numElementsTodo;
/* 019 */ private scala.collection.Iterator range_input;
/* 020 */ private UnsafeRow range_result;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
/* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
/* 025 */ private UnsafeRow agg_result;
/* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 028 */
/* 029 */ public GeneratedIterator(Object[] references) {
/* 030 */ this.references = references;
/* 031 */ }
/* 032 */
/* 033 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */ partitionIndex = index;
/* 035 */ this.inputs = inputs;
/* 036 */ agg_initAgg = false;
/* 037 */
/* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */ range_initRange = false;
/* 041 */ range_number = 0L;
/* 042 */ range_taskContext = TaskContext.get();
/* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */ range_batchEnd = 0;
/* 045 */ range_numElementsTodo = 0L;
/* 046 */ range_input = inputs[0];
/* 047 */ range_result = new UnsafeRow(1);
/* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */ agg_result = new UnsafeRow(1);
/* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */
/* 056 */ }
/* 057 */
/* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */ // initialize aggregation buffer
/* 060 */ agg_bufIsNull = false;
/* 061 */ agg_bufValue = 0L;
/* 062 */
/* 063 */ // initialize Range
/* 064 */ if (!range_initRange) {
/* 065 */ range_initRange = true;
/* 066 */ initRange(partitionIndex);
/* 067 */ }
/* 068 */
/* 069 */ while (true) {
/* 070 */ long range_range = range_batchEnd - range_number;
/* 071 */ if (range_range != 0L) {
/* 072 */ int range_localEnd = (int)(range_range / 1L);
/* 073 */ for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) {
/* 074 */ long range_value = ((long)range_localIdx * 1L) + range_number;
/* 075 */
/* 076 */ // do aggregate
/* 077 */ // common sub-expressions
/* 078 */
/* 079 */ // evaluate aggregate function
/* 080 */ boolean agg_isNull1 = false;
/* 081 */
/* 082 */ long agg_value1 = -1L;
/* 083 */ agg_value1 = agg_bufValue + 1L;
/* 084 */ // update aggregation buffer
/* 085 */ agg_bufIsNull = false;
/* 086 */ agg_bufValue = agg_value1;
/* 087 */
/* 088 */ // shouldStop check is eliminated
/* 089 */ }
/* 090 */ range_number = range_batchEnd;
/* 091 */ }
/* 092 */
/* 093 */ if (range_taskContext.isInterrupted()) {
/* 094 */ throw new TaskKilledException();
/* 095 */ }
/* 096 */
/* 097 */ long range_nextBatchTodo;
/* 098 */ if (range_numElementsTodo > 1000L) {
/* 099 */ range_nextBatchTodo = 1000L;
/* 100 */ range_numElementsTodo -= 1000L;
/* 101 */ } else {
/* 102 */ range_nextBatchTodo = range_numElementsTodo;
/* 103 */ range_numElementsTodo = 0;
/* 104 */ if (range_nextBatchTodo == 0) break;
/* 105 */ }
/* 106 */ range_numOutputRows.add(range_nextBatchTodo);
/* 107 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 108 */
/* 109 */ range_batchEnd += range_nextBatchTodo * 1L;
/* 110 */ }
/* 111 */
/* 112 */ }
/* 113 */
/* 114 */ private void initRange(int idx) {
/* 115 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 116 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 117 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 118 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 119 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 120 */ long partitionEnd;
/* 121 */
/* 122 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 123 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 124 */ range_number = Long.MAX_VALUE;
/* 125 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 126 */ range_number = Long.MIN_VALUE;
/* 127 */ } else {
/* 128 */ range_number = st.longValue();
/* 129 */ }
/* 130 */ range_batchEnd = range_number;
/* 131 */
/* 132 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 133 */ .multiply(step).add(start);
/* 134 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 135 */ partitionEnd = Long.MAX_VALUE;
/* 136 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 137 */ partitionEnd = Long.MIN_VALUE;
/* 138 */ } else {
/* 139 */ partitionEnd = end.longValue();
/* 140 */ }
/* 141 */
/* 142 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 143 */ java.math.BigInteger.valueOf(range_number));
/* 144 */ range_numElementsTodo = startToEnd.divide(step).longValue();
/* 145 */ if (range_numElementsTodo < 0) {
/* 146 */ range_numElementsTodo = 0;
/* 147 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 148 */ range_numElementsTodo++;
/* 149 */ }
/* 150 */ }
/* 151 */
/* 152 */ protected void processNext() throws java.io.IOException {
/* 153 */ while (!agg_initAgg) {
/* 154 */ agg_initAgg = true;
/* 155 */ long agg_beforeAgg = System.nanoTime();
/* 156 */ agg_doAggregateWithoutKey();
/* 157 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
/* 158 */
/* 159 */ // output the result
/* 160 */
/* 161 */ agg_numOutputRows.add(1);
/* 162 */ agg_rowWriter.zeroOutNullBytes();
/* 163 */
/* 164 */ if (agg_bufIsNull) {
/* 165 */ agg_rowWriter.setNullAt(0);
/* 166 */ } else {
/* 167 */ agg_rowWriter.write(0, agg_bufValue);
/* 168 */ }
/* 169 */ append(agg_result);
/* 170 */ }
/* 171 */ }
/* 172 */ }
```
Part of the `shouldStop()` suppression was originally developed by inouehrs.
## How was this patch tested?
Added new tests to `DataFrameRangeSuite`.
Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Closes #17122 from kiszk/SPARK-19786.