diff options
7 files changed, 1356 insertions, 126 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java new file mode 100644 index 0000000000..b6130d1f33 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; + +/** + * An implementation of `RowBasedKeyValueBatch` in which all key-value records have same length. + * + * The format for each record looks like this: + * [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen] + * [8 bytes pointer to next] + * Thus, record length = klen + vlen + 8 + */ +public final class FixedLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch { + private final int klen; + private final int vlen; + private final int recordLength; + + private final long getKeyOffsetForFixedLengthRecords(int rowId) { + return recordStartOffset + rowId * (long) recordLength; + } + + /** + * Append a key value pair. + * It copies data into the backing MemoryBlock. + * Returns an UnsafeRow pointing to the value if succeeds, otherwise returns null. + */ + @Override + public final UnsafeRow appendRow(Object kbase, long koff, int klen, + Object vbase, long voff, int vlen) { + // if run out of max supported rows or page size, return null + if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { + return null; + } + + long offset = page.getBaseOffset() + pageCursor; + final long recordOffset = offset; + Platform.copyMemory(kbase, koff, base, offset, klen); + offset += klen; + Platform.copyMemory(vbase, voff, base, offset, vlen); + offset += vlen; + Platform.putLong(base, offset, 0); + + pageCursor += recordLength; + + keyRowId = numRows; + keyRow.pointTo(base, recordOffset, klen); + valueRow.pointTo(base, recordOffset + klen, vlen + 4); + numRows++; + return valueRow; + } + + /** + * Returns the key row in this batch at `rowId`. Returned key row is reused across calls. + */ + @Override + public final UnsafeRow getKeyRow(int rowId) { + assert(rowId >= 0); + assert(rowId < numRows); + if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached + long offset = getKeyOffsetForFixedLengthRecords(rowId); + keyRow.pointTo(base, offset, klen); + // set keyRowId so we can check if desired row is cached + keyRowId = rowId; + } + return keyRow; + } + + /** + * Returns the value row by two steps: + * 1) looking up the key row with the same id (skipped if the key row is cached) + * 2) retrieve the value row by reusing the metadata from step 1) + * In most times, 1) is skipped because `getKeyRow(id)` is often called before `getValueRow(id)`. + */ + @Override + protected final UnsafeRow getValueFromKey(int rowId) { + if (keyRowId != rowId) { + getKeyRow(rowId); + } + assert(rowId >= 0); + valueRow.pointTo(base, keyRow.getBaseOffset() + klen, vlen + 4); + return valueRow; + } + + /** + * Returns an iterator to go through all rows + */ + @Override + public final org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() { + return new org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>() { + private final UnsafeRow key = new UnsafeRow(keySchema.length()); + private final UnsafeRow value = new UnsafeRow(valueSchema.length()); + + private long offsetInPage = 0; + private int recordsInPage = 0; + + private boolean initialized = false; + + private void init() { + if (page != null) { + offsetInPage = page.getBaseOffset(); + recordsInPage = numRows; + } + initialized = true; + } + + @Override + public boolean next() { + if (!initialized) init(); + //searching for the next non empty page is records is now zero + if (recordsInPage == 0) { + freeCurrentPage(); + return false; + } + + key.pointTo(base, offsetInPage, klen); + value.pointTo(base, offsetInPage + klen, vlen + 4); + + offsetInPage += recordLength; + recordsInPage -= 1; + return true; + } + + @Override + public UnsafeRow getKey() { + return key; + } + + @Override + public UnsafeRow getValue() { + return value; + } + + @Override + public void close() { + // do nothing + } + + private void freeCurrentPage() { + if (page != null) { + freePage(page); + page = null; + } + } + }; + } + + protected FixedLengthRowBasedKeyValueBatch(StructType keySchema, StructType valueSchema, + int maxRows, TaskMemoryManager manager) { + super(keySchema, valueSchema, maxRows, manager); + klen = keySchema.defaultSize() + + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length()); + vlen = valueSchema.defaultSize() + + UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length()); + recordLength = klen + vlen + 8; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java new file mode 100644 index 0000000000..cea9d5d5bc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions; + +import java.io.IOException; + +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.memory.MemoryBlock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * RowBasedKeyValueBatch stores key value pairs in contiguous memory region. + * + * Each key or value is stored as a single UnsafeRow. Each record contains one key and one value + * and some auxiliary data, which differs based on implementation: + * i.e., `FixedLengthRowBasedKeyValueBatch` and `VariableLengthRowBasedKeyValueBatch`. + * + * We use `FixedLengthRowBasedKeyValueBatch` if all fields in the key and the value are fixed-length + * data types. Otherwise we use `VariableLengthRowBasedKeyValueBatch`. + * + * RowBasedKeyValueBatch is backed by a single page / MemoryBlock (defaults to 64MB). If the page + * is full, the aggregate logic should fallback to a second level, larger hash map. We intentionally + * use the single-page design because it simplifies memory address encoding & decoding for each + * key-value pair. Because the maximum capacity for RowBasedKeyValueBatch is only 2^16, it is + * unlikely we need a second page anyway. Filling the page requires an average size for key value + * pairs to be larger than 1024 bytes. + * + */ +public abstract class RowBasedKeyValueBatch extends MemoryConsumer { + protected final Logger logger = LoggerFactory.getLogger(RowBasedKeyValueBatch.class); + + private static final int DEFAULT_CAPACITY = 1 << 16; + private static final long DEFAULT_PAGE_SIZE = 64 * 1024 * 1024; + + protected final StructType keySchema; + protected final StructType valueSchema; + protected final int capacity; + protected int numRows = 0; + + // ids for current key row and value row being retrieved + protected int keyRowId = -1; + + // placeholder for key and value corresponding to keyRowId. + protected final UnsafeRow keyRow; + protected final UnsafeRow valueRow; + + protected MemoryBlock page = null; + protected Object base = null; + protected final long recordStartOffset; + protected long pageCursor = 0; + + public static RowBasedKeyValueBatch allocate(StructType keySchema, StructType valueSchema, + TaskMemoryManager manager) { + return allocate(keySchema, valueSchema, manager, DEFAULT_CAPACITY); + } + + public static RowBasedKeyValueBatch allocate(StructType keySchema, StructType valueSchema, + TaskMemoryManager manager, int maxRows) { + boolean allFixedLength = true; + // checking if there is any variable length fields + // there is probably a more succinct impl of this + for (String name : keySchema.fieldNames()) { + allFixedLength = allFixedLength + && UnsafeRow.isFixedLength(keySchema.apply(name).dataType()); + } + for (String name : valueSchema.fieldNames()) { + allFixedLength = allFixedLength + && UnsafeRow.isFixedLength(valueSchema.apply(name).dataType()); + } + + if (allFixedLength) { + return new FixedLengthRowBasedKeyValueBatch(keySchema, valueSchema, maxRows, manager); + } else { + return new VariableLengthRowBasedKeyValueBatch(keySchema, valueSchema, maxRows, manager); + } + } + + protected RowBasedKeyValueBatch(StructType keySchema, StructType valueSchema, int maxRows, + TaskMemoryManager manager) { + super(manager, manager.pageSizeBytes(), manager.getTungstenMemoryMode()); + + this.keySchema = keySchema; + this.valueSchema = valueSchema; + this.capacity = maxRows; + + this.keyRow = new UnsafeRow(keySchema.length()); + this.valueRow = new UnsafeRow(valueSchema.length()); + + if (!acquirePage(DEFAULT_PAGE_SIZE)) { + page = null; + recordStartOffset = 0; + } else { + base = page.getBaseObject(); + recordStartOffset = page.getBaseOffset(); + } + } + + public final int numRows() { return numRows; } + + public final void close() { + if (page != null) { + freePage(page); + page = null; + } + } + + private final boolean acquirePage(long requiredSize) { + try { + page = allocatePage(requiredSize); + } catch (OutOfMemoryError e) { + logger.warn("Failed to allocate page ({} bytes).", requiredSize); + return false; + } + base = page.getBaseObject(); + pageCursor = 0; + return true; + } + + /** + * Append a key value pair. + * It copies data into the backing MemoryBlock. + * Returns an UnsafeRow pointing to the value if succeeds, otherwise returns null. + */ + public abstract UnsafeRow appendRow(Object kbase, long koff, int klen, + Object vbase, long voff, int vlen); + + /** + * Returns the key row in this batch at `rowId`. Returned key row is reused across calls. + */ + public abstract UnsafeRow getKeyRow(int rowId); + + /** + * Returns the value row in this batch at `rowId`. Returned value row is reused across calls. + * Because `getValueRow(id)` is always called after `getKeyRow(id)` with the same id, we use + * `getValueFromKey(id) to retrieve value row, which reuses metadata from the cached key. + */ + public final UnsafeRow getValueRow(int rowId) { + return getValueFromKey(rowId); + } + + /** + * Returns the value row by two steps: + * 1) looking up the key row with the same id (skipped if the key row is cached) + * 2) retrieve the value row by reusing the metadata from step 1) + * In most times, 1) is skipped because `getKeyRow(id)` is often called before `getValueRow(id)`. + */ + protected abstract UnsafeRow getValueFromKey(int rowId); + + /** + * Sometimes the TaskMemoryManager may call spill() on its associated MemoryConsumers to make + * space for new consumers. For RowBasedKeyValueBatch, we do not actually spill and return 0. + * We should not throw OutOfMemory exception here because other associated consumers might spill + */ + public final long spill(long size, MemoryConsumer trigger) throws IOException { + logger.warn("Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0."); + return 0; + } + + /** + * Returns an iterator to go through all rows + */ + public abstract org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java new file mode 100644 index 0000000000..f4002ee0d5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; + +/** + * An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths. + * + * The format for each record looks like this: + * [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen] + * [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen] + * [8 bytes pointer to next] + * Thus, record length = 4 + 4 + klen + vlen + 8 + */ +public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch { + // full addresses for key rows and value rows + private final long[] keyOffsets; + + /** + * Append a key value pair. + * It copies data into the backing MemoryBlock. + * Returns an UnsafeRow pointing to the value if succeeds, otherwise returns null. + */ + @Override + public final UnsafeRow appendRow(Object kbase, long koff, int klen, + Object vbase, long voff, int vlen) { + final long recordLength = 8 + klen + vlen + 8; + // if run out of max supported rows or page size, return null + if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { + return null; + } + + long offset = page.getBaseOffset() + pageCursor; + final long recordOffset = offset; + Platform.putInt(base, offset, klen + vlen + 4); + Platform.putInt(base, offset + 4, klen); + + offset += 8; + Platform.copyMemory(kbase, koff, base, offset, klen); + offset += klen; + Platform.copyMemory(vbase, voff, base, offset, vlen); + offset += vlen; + Platform.putLong(base, offset, 0); + + pageCursor += recordLength; + + keyOffsets[numRows] = recordOffset + 8; + + keyRowId = numRows; + keyRow.pointTo(base, recordOffset + 8, klen); + valueRow.pointTo(base, recordOffset + 8 + klen, vlen + 4); + numRows++; + return valueRow; + } + + /** + * Returns the key row in this batch at `rowId`. Returned key row is reused across calls. + */ + @Override + public UnsafeRow getKeyRow(int rowId) { + assert(rowId >= 0); + assert(rowId < numRows); + if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached + long offset = keyOffsets[rowId]; + int klen = Platform.getInt(base, offset - 4); + keyRow.pointTo(base, offset, klen); + // set keyRowId so we can check if desired row is cached + keyRowId = rowId; + } + return keyRow; + } + + /** + * Returns the value row by two steps: + * 1) looking up the key row with the same id (skipped if the key row is cached) + * 2) retrieve the value row by reusing the metadata from step 1) + * In most times, 1) is skipped because `getKeyRow(id)` is often called before `getValueRow(id)`. + */ + @Override + public final UnsafeRow getValueFromKey(int rowId) { + if (keyRowId != rowId) { + getKeyRow(rowId); + } + assert(rowId >= 0); + long offset = keyRow.getBaseOffset(); + int klen = keyRow.getSizeInBytes(); + int vlen = Platform.getInt(base, offset - 8) - klen - 4; + valueRow.pointTo(base, offset + klen, vlen + 4); + return valueRow; + } + + /** + * Returns an iterator to go through all rows + */ + @Override + public final org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() { + return new org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>() { + private final UnsafeRow key = new UnsafeRow(keySchema.length()); + private final UnsafeRow value = new UnsafeRow(valueSchema.length()); + + private long offsetInPage = 0; + private int recordsInPage = 0; + + private int currentklen; + private int currentvlen; + private int totalLength; + + private boolean initialized = false; + + private void init() { + if (page != null) { + offsetInPage = page.getBaseOffset(); + recordsInPage = numRows; + } + initialized = true; + } + + @Override + public boolean next() { + if (!initialized) init(); + //searching for the next non empty page is records is now zero + if (recordsInPage == 0) { + freeCurrentPage(); + return false; + } + + totalLength = Platform.getInt(base, offsetInPage) - 4; + currentklen = Platform.getInt(base, offsetInPage + 4); + currentvlen = totalLength - currentklen; + + key.pointTo(base, offsetInPage + 8, currentklen); + value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen + 4); + + offsetInPage += 8 + totalLength + 8; + recordsInPage -= 1; + return true; + } + + @Override + public UnsafeRow getKey() { + return key; + } + + @Override + public UnsafeRow getValue() { + return value; + } + + @Override + public void close() { + // do nothing + } + + private void freeCurrentPage() { + if (page != null) { + freePage(page); + page = null; + } + } + }; + } + + protected VariableLengthRowBasedKeyValueBatch(StructType keySchema, StructType valueSchema, + int maxRows, TaskMemoryManager manager) { + super(keySchema, valueSchema, maxRows, manager); + this.keyOffsets = new long[maxRows]; + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java new file mode 100644 index 0000000000..0dd129cea7 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.unsafe.types.UTF8String; + +import java.util.Random; + +public class RowBasedKeyValueBatchSuite { + + private final Random rand = new Random(42); + + private TestMemoryManager memoryManager; + private TaskMemoryManager taskMemoryManager; + private StructType keySchema = new StructType().add("k1", DataTypes.LongType) + .add("k2", DataTypes.StringType); + private StructType fixedKeySchema = new StructType().add("k1", DataTypes.LongType) + .add("k2", DataTypes.LongType); + private StructType valueSchema = new StructType().add("count", DataTypes.LongType) + .add("sum", DataTypes.LongType); + private int DEFAULT_CAPACITY = 1 << 16; + + private String getRandomString(int length) { + Assert.assertTrue(length >= 0); + final byte[] bytes = new byte[length]; + rand.nextBytes(bytes); + return new String(bytes); + } + + private UnsafeRow makeKeyRow(long k1, String k2) { + UnsafeRow row = new UnsafeRow(2); + BufferHolder holder = new BufferHolder(row, 32); + UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); + holder.reset(); + writer.write(0, k1); + writer.write(1, UTF8String.fromString(k2)); + row.setTotalSize(holder.totalSize()); + return row; + } + + private UnsafeRow makeKeyRow(long k1, long k2) { + UnsafeRow row = new UnsafeRow(2); + BufferHolder holder = new BufferHolder(row, 0); + UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); + holder.reset(); + writer.write(0, k1); + writer.write(1, k2); + row.setTotalSize(holder.totalSize()); + return row; + } + + private UnsafeRow makeValueRow(long v1, long v2) { + UnsafeRow row = new UnsafeRow(2); + BufferHolder holder = new BufferHolder(row, 0); + UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); + holder.reset(); + writer.write(0, v1); + writer.write(1, v2); + row.setTotalSize(holder.totalSize()); + return row; + } + + private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { + return batch.appendRow(key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(), + value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes()); + } + + private void updateValueRow(UnsafeRow row, long v1, long v2) { + row.setLong(0, v1); + row.setLong(1, v2); + } + + private boolean checkKey(UnsafeRow row, long k1, String k2) { + return (row.getLong(0) == k1) + && (row.getUTF8String(1).equals(UTF8String.fromString(k2))); + } + + private boolean checkKey(UnsafeRow row, long k1, long k2) { + return (row.getLong(0) == k1) + && (row.getLong(1) == k2); + } + + private boolean checkValue(UnsafeRow row, long v1, long v2) { + return (row.getLong(0) == v1) && (row.getLong(1) == v2); + } + + @Before + public void setup() { + memoryManager = new TestMemoryManager(new SparkConf() + .set("spark.memory.offHeap.enabled", "false") + .set("spark.shuffle.spill.compress", "false") + .set("spark.shuffle.compress", "false")); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + } + + @After + public void tearDown() { + if (taskMemoryManager != null) { + Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); + long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask(); + taskMemoryManager = null; + Assert.assertEquals(0L, leakedMemory); + } + } + + + @Test + public void emptyBatch() throws Exception { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + Assert.assertEquals(0, batch.numRows()); + try { + batch.getKeyRow(-1); + Assert.fail("Should not be able to get row -1"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + try { + batch.getValueRow(-1); + Assert.fail("Should not be able to get row -1"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + try { + batch.getKeyRow(0); + Assert.fail("Should not be able to get row 0 when batch is empty"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + try { + batch.getValueRow(0); + Assert.fail("Should not be able to get row 0 when batch is empty"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + Assert.assertFalse(batch.rowIterator().next()); + } finally { + batch.close(); + } + } + + @Test + public void batchType() throws Exception { + RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + Assert.assertEquals(batch1.getClass(), VariableLengthRowBasedKeyValueBatch.class); + Assert.assertEquals(batch2.getClass(), FixedLengthRowBasedKeyValueBatch.class); + } finally { + batch1.close(); + batch2.close(); + } + } + + @Test + public void setAndRetrieve() { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); + Assert.assertTrue(checkValue(ret1, 1, 1)); + UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); + Assert.assertTrue(checkValue(ret2, 2, 2)); + UnsafeRow ret3 = appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3)); + Assert.assertTrue(checkValue(ret3, 3, 3)); + Assert.assertEquals(3, batch.numRows()); + UnsafeRow retrievedKey1 = batch.getKeyRow(0); + Assert.assertTrue(checkKey(retrievedKey1, 1, "A")); + UnsafeRow retrievedKey2 = batch.getKeyRow(1); + Assert.assertTrue(checkKey(retrievedKey2, 2, "B")); + UnsafeRow retrievedValue1 = batch.getValueRow(1); + Assert.assertTrue(checkValue(retrievedValue1, 2, 2)); + UnsafeRow retrievedValue2 = batch.getValueRow(2); + Assert.assertTrue(checkValue(retrievedValue2, 3, 3)); + try { + batch.getKeyRow(3); + Assert.fail("Should not be able to get row 3"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + try { + batch.getValueRow(3); + Assert.fail("Should not be able to get row 3"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + } finally { + batch.close(); + } + } + + @Test + public void setUpdateAndRetrieve() { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); + Assert.assertEquals(1, batch.numRows()); + UnsafeRow retrievedValue = batch.getValueRow(0); + updateValueRow(retrievedValue, 2, 2); + UnsafeRow retrievedValue2 = batch.getValueRow(0); + Assert.assertTrue(checkValue(retrievedValue2, 2, 2)); + } finally { + batch.close(); + } + } + + + @Test + public void iteratorTest() throws Exception { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); + appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); + appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3)); + Assert.assertEquals(3, batch.numRows()); + org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator + = batch.rowIterator(); + Assert.assertTrue(iterator.next()); + UnsafeRow key1 = iterator.getKey(); + UnsafeRow value1 = iterator.getValue(); + Assert.assertTrue(checkKey(key1, 1, "A")); + Assert.assertTrue(checkValue(value1, 1, 1)); + Assert.assertTrue(iterator.next()); + UnsafeRow key2 = iterator.getKey(); + UnsafeRow value2 = iterator.getValue(); + Assert.assertTrue(checkKey(key2, 2, "B")); + Assert.assertTrue(checkValue(value2, 2, 2)); + Assert.assertTrue(iterator.next()); + UnsafeRow key3 = iterator.getKey(); + UnsafeRow value3 = iterator.getValue(); + Assert.assertTrue(checkKey(key3, 3, "C")); + Assert.assertTrue(checkValue(value3, 3, 3)); + Assert.assertFalse(iterator.next()); + } finally { + batch.close(); + } + } + + @Test + public void fixedLengthTest() throws Exception { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1)); + appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2)); + appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3)); + UnsafeRow retrievedKey1 = batch.getKeyRow(0); + Assert.assertTrue(checkKey(retrievedKey1, 11, 11)); + UnsafeRow retrievedKey2 = batch.getKeyRow(1); + Assert.assertTrue(checkKey(retrievedKey2, 22, 22)); + UnsafeRow retrievedValue1 = batch.getValueRow(1); + Assert.assertTrue(checkValue(retrievedValue1, 2, 2)); + UnsafeRow retrievedValue2 = batch.getValueRow(2); + Assert.assertTrue(checkValue(retrievedValue2, 3, 3)); + Assert.assertEquals(3, batch.numRows()); + org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator + = batch.rowIterator(); + Assert.assertTrue(iterator.next()); + UnsafeRow key1 = iterator.getKey(); + UnsafeRow value1 = iterator.getValue(); + Assert.assertTrue(checkKey(key1, 11, 11)); + Assert.assertTrue(checkValue(value1, 1, 1)); + Assert.assertTrue(iterator.next()); + UnsafeRow key2 = iterator.getKey(); + UnsafeRow value2 = iterator.getValue(); + Assert.assertTrue(checkKey(key2, 22, 22)); + Assert.assertTrue(checkValue(value2, 2, 2)); + Assert.assertTrue(iterator.next()); + UnsafeRow key3 = iterator.getKey(); + UnsafeRow value3 = iterator.getValue(); + Assert.assertTrue(checkKey(key3, 33, 33)); + Assert.assertTrue(checkValue(value3, 3, 3)); + Assert.assertFalse(iterator.next()); + } finally { + batch.close(); + } + } + + @Test + public void appendRowUntilExceedingCapacity() throws Exception { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, 10); + try { + UnsafeRow key = makeKeyRow(1, "A"); + UnsafeRow value = makeValueRow(1, 1); + for (int i = 0; i < 10; i++) { + appendRow(batch, key, value); + } + UnsafeRow ret = appendRow(batch, key, value); + Assert.assertEquals(batch.numRows(), 10); + Assert.assertNull(ret); + org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator + = batch.rowIterator(); + for (int i = 0; i < 10; i++) { + Assert.assertTrue(iterator.next()); + UnsafeRow key1 = iterator.getKey(); + UnsafeRow value1 = iterator.getValue(); + Assert.assertTrue(checkKey(key1, 1, "A")); + Assert.assertTrue(checkValue(value1, 1, 1)); + } + Assert.assertFalse(iterator.next()); + } finally { + batch.close(); + } + } + + @Test + public void appendRowUntilExceedingPageSize() throws Exception { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, 64 * 1024 * 1024); //enough capacity + try { + UnsafeRow key = makeKeyRow(1, "A"); + UnsafeRow value = makeValueRow(1, 1); + int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8; + int totalSize = 4; + int numRows = 0; + while (totalSize + recordLength < 64 * 1024 * 1024) { // default page size + appendRow(batch, key, value); + totalSize += recordLength; + numRows++; + } + UnsafeRow ret = appendRow(batch, key, value); + Assert.assertEquals(batch.numRows(), numRows); + Assert.assertNull(ret); + org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator + = batch.rowIterator(); + for (int i = 0; i < numRows; i++) { + Assert.assertTrue(iterator.next()); + UnsafeRow key1 = iterator.getKey(); + UnsafeRow value1 = iterator.getValue(); + Assert.assertTrue(checkKey(key1, 1, "A")); + Assert.assertTrue(checkValue(value1, 1, 1)); + } + Assert.assertFalse(iterator.next()); + } finally { + batch.close(); + } + } + + @Test + public void failureToAllocateFirstPage() throws Exception { + memoryManager.limit(1024); + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { + UnsafeRow key = makeKeyRow(1, "A"); + UnsafeRow value = makeValueRow(11, 11); + UnsafeRow ret = appendRow(batch, key, value); + Assert.assertNull(ret); + Assert.assertFalse(batch.rowIterator().next()); + } finally { + batch.close(); + } + } + + @Test + public void randomizedTest() { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + int numEntry = 100; + long[] expectedK1 = new long[numEntry]; + String[] expectedK2 = new String[numEntry]; + long[] expectedV1 = new long[numEntry]; + long[] expectedV2 = new long[numEntry]; + + for (int i = 0; i < numEntry; i++) { + long k1 = rand.nextLong(); + String k2 = getRandomString(rand.nextInt(256)); + long v1 = rand.nextLong(); + long v2 = rand.nextLong(); + appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); + expectedK1[i] = k1; + expectedK2[i] = k2; + expectedV1[i] = v1; + expectedV2[i] = v2; + } + try { + for (int j = 0; j < 10000; j++) { + int rowId = rand.nextInt(numEntry); + if (rand.nextBoolean()) { + UnsafeRow key = batch.getKeyRow(rowId); + Assert.assertTrue(checkKey(key, expectedK1[rowId], expectedK2[rowId])); + } + if (rand.nextBoolean()) { + UnsafeRow value = batch.getValueRow(rowId); + Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId])); + } + } + } finally { + batch.close(); + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala new file mode 100644 index 0000000000..90deb20e97 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types._ + +/** + * This is a helper class to generate an append-only row-based hash map that can act as a 'cache' + * for extremely fast key-value lookups while evaluating aggregates (and fall back to the + * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in HashAggregate to speed + * up aggregates w/ key. + * + * NOTE: the generated hash map currently doesn't support nullable keys and falls back to the + * `BytesToBytesMap` to store them. + */ +abstract class HashMapGenerator( + ctx: CodegenContext, + aggregateExpressions: Seq[AggregateExpression], + generatedClassName: String, + groupingKeySchema: StructType, + bufferSchema: StructType) { + case class Buffer(dataType: DataType, name: String) + + val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) + val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) + val groupingKeySignature = + groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + val buffVars: Seq[ExprCode] = { + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val initExpr = functions.flatMap(f => f.initialValues) + initExpr.map { e => + val isNull = ctx.freshName("bufIsNull") + val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") + val ev = e.genCode(ctx) + val initVars = + s""" + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + """.stripMargin + ExprCode(ev.code + initVars, isNull, value) + } + } + + def generate(): String = { + s""" + |public class $generatedClassName { + |${initializeAggregateHashMap()} + | + |${generateFindOrInsert()} + | + |${generateEquals()} + | + |${generateHashFunction()} + | + |${generateRowIterator()} + | + |${generateClose()} + |} + """.stripMargin + } + + protected def initializeAggregateHashMap(): String + + /** + * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For + * instance, if we have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private long hash(long agg_key, long agg_key1) { + * return agg_key ^ agg_key1; + * } + * }}} + */ + protected final def generateHashFunction(): String = { + val hash = ctx.freshName("hash") + + def genHashForKeys(groupingKeys: Seq[Buffer]): String = { + groupingKeys.map { key => + val result = ctx.freshName("result") + s""" + |${genComputeHash(ctx, key.name, key.dataType, result)} + |$hash = ($hash ^ (0x9e3779b9)) + $result + ($hash << 6) + ($hash >>> 2); + """.stripMargin + }.mkString("\n") + } + + s""" + |private long hash($groupingKeySignature) { + | long $hash = 0; + | ${genHashForKeys(groupingKeys)} + | return $hash; + |} + """.stripMargin + } + + /** + * Generates a method that returns true if the group-by keys exist at a given index. + */ + protected def generateEquals(): String + + /** + * Generates a method that returns a row which keeps track of the + * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the + * generated method adds the corresponding row in the associated key value batch. + */ + protected def generateFindOrInsert(): String + + protected def generateRowIterator(): String + + protected final def generateClose(): String = { + s""" + |public void close() { + | batch.close(); + |} + """.stripMargin + } + + protected final def genComputeHash( + ctx: CodegenContext, + input: String, + dataType: DataType, + result: String): String = { + def hashInt(i: String): String = s"int $result = $i;" + def hashLong(l: String): String = s"long $result = $l;" + def hashBytes(b: String): String = { + val hash = ctx.freshName("hash") + val bytes = ctx.freshName("bytes") + s""" + |int $result = 0; + |byte[] $bytes = $b; + |for (int i = 0; i < $bytes.length; i++) { + | ${genComputeHash(ctx, s"$bytes[i]", ByteType, hash)} + | $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2); + |} + """.stripMargin + } + + dataType match { + case BooleanType => hashInt(s"$input ? 1 : 0") + case ByteType | ShortType | IntegerType | DateType => hashInt(input) + case LongType | TimestampType => hashLong(input) + case FloatType => hashInt(s"Float.floatToIntBits($input)") + case DoubleType => hashLong(s"Double.doubleToLongBits($input)") + case d: DecimalType => + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(s"$input.toUnscaledLong()") + } else { + val bytes = ctx.freshName("bytes") + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ + } + case StringType => hashBytes(s"$input.getBytes()") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala new file mode 100644 index 0000000000..1dea33037c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.types._ + +/** + * This is a helper class to generate an append-only row-based hash map that can act as a 'cache' + * for extremely fast key-value lookups while evaluating aggregates (and fall back to the + * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in HashAggregate to speed + * up aggregates w/ key. + * + * We also have VectorizedHashMapGenerator, which generates a append-only vectorized hash map. + * We choose one of the two as the 1st level, fast hash map during aggregation. + * + * NOTE: This row-based hash map currently doesn't support nullable keys and falls back to the + * `BytesToBytesMap` to store them. + */ +class RowBasedHashMapGenerator( + ctx: CodegenContext, + aggregateExpressions: Seq[AggregateExpression], + generatedClassName: String, + groupingKeySchema: StructType, + bufferSchema: StructType) + extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, + groupingKeySchema, bufferSchema) { + + protected def initializeAggregateHashMap(): String = { + val generatedKeySchema: String = + s"new org.apache.spark.sql.types.StructType()" + + groupingKeySchema.map { key => + key.dataType match { + case d: DecimalType => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType( + |${d.precision}, ${d.scale}))""".stripMargin + case _ => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + } + }.mkString("\n").concat(";") + + val generatedValueSchema: String = + s"new org.apache.spark.sql.types.StructType()" + + bufferSchema.map { key => + key.dataType match { + case d: DecimalType => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType( + |${d.precision}, ${d.scale}))""".stripMargin + case _ => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + } + }.mkString("\n").concat(";") + + s""" + | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; + | private int[] buckets; + | private int capacity = 1 << 16; + | private double loadFactor = 0.5; + | private int numBuckets = (int) (capacity / loadFactor); + | private int maxSteps = 2; + | private int numRows = 0; + | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema + | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema + | private Object emptyVBase; + | private long emptyVOff; + | private int emptyVLen; + | private boolean isBatchFull = false; + | + | + | public $generatedClassName( + | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, + | InternalRow emptyAggregationBuffer) { + | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch + | .allocate(keySchema, valueSchema, taskMemoryManager, capacity); + | + | final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); + | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + | + | emptyVBase = emptyBuffer; + | emptyVOff = Platform.BYTE_ARRAY_OFFSET; + | emptyVLen = emptyBuffer.length; + | + | buckets = new int[numBuckets]; + | java.util.Arrays.fill(buckets, -1); + | } + """.stripMargin + } + + /** + * Generates a method that returns true if the group-by keys exist at a given index in the + * associated [[org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch]]. + * + */ + protected def generateEquals(): String = { + + def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { + groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => + s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + key.dataType, ordinal.toString()), key.name)})""" + }.mkString(" && ") + } + + s""" + |private boolean equals(int idx, $groupingKeySignature) { + | UnsafeRow row = batch.getKeyRow(buckets[idx]); + | return ${genEqualsForKeys(groupingKeys)}; + |} + """.stripMargin + } + + /** + * Generates a method that returns a + * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] which keeps track of the + * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the + * generated method adds the corresponding row in the associated + * [[org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch]]. + * + */ + protected def generateFindOrInsert(): String = { + val numVarLenFields = groupingKeys.map(_.dataType).count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } + + val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => + s"agg_rowWriter.write(${ordinal}, ${key.name})"} + .mkString(";\n") + + s""" + |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${ + groupingKeySignature}) { + | long h = hash(${groupingKeys.map(_.name).mkString(", ")}); + | int step = 0; + | int idx = (int) h & (numBuckets - 1); + | while (step < maxSteps) { + | // Return bucket index if it's either an empty slot or already contains the key + | if (buckets[idx] == -1) { + | if (numRows < capacity && !isBatchFull) { + | // creating the unsafe for new entry + | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); + | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder + | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, + | ${numVarLenFields * 32}); + | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter + | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( + | agg_holder, + | ${groupingKeySchema.length}); + | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | agg_rowWriter.zeroOutNullBytes(); + | ${createUnsafeRowForKey}; + | agg_result.setTotalSize(agg_holder.totalSize()); + | Object kbase = agg_result.getBaseObject(); + | long koff = agg_result.getBaseOffset(); + | int klen = agg_result.getSizeInBytes(); + | + | UnsafeRow vRow + | = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen); + | if (vRow == null) { + | isBatchFull = true; + | } else { + | buckets[idx] = numRows++; + | } + | return vRow; + | } else { + | // No more space + | return null; + | } + | } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) { + | return batch.getValueRow(buckets[idx]); + | } + | idx = (idx + 1) & (numBuckets - 1); + | step++; + | } + | // Didn't find it + | return null; + |} + """.stripMargin + } + + protected def generateRowIterator(): String = { + s""" + |public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() { + | return batch.rowIterator(); + |} + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index b4a9059299..7418df90b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} import org.apache.spark.sql.types._ /** @@ -44,49 +44,11 @@ class VectorizedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) { - case class Buffer(dataType: DataType, name: String) - val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) - val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) - val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") - val buffVars: Seq[ExprCode] = { - val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val initExpr = functions.flatMap(f => f.initialValues) - initExpr.map { e => - val isNull = ctx.freshName("bufIsNull") - val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") - val ev = e.genCode(ctx) - val initVars = - s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; - """.stripMargin - ExprCode(ev.code + initVars, isNull, value) - } - } - - def generate(): String = { - s""" - |public class $generatedClassName { - |${initializeAggregateHashMap()} - | - |${generateFindOrInsert()} - | - |${generateEquals()} - | - |${generateHashFunction()} - | - |${generateRowIterator()} - | - |${generateClose()} - |} - """.stripMargin - } + bufferSchema: StructType) + extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, + groupingKeySchema, bufferSchema) { - private def initializeAggregateHashMap(): String = { + protected def initializeAggregateHashMap(): String = { val generatedSchema: String = s"new org.apache.spark.sql.types.StructType()" + (groupingKeySchema ++ bufferSchema).map { key => @@ -140,37 +102,6 @@ class VectorizedHashMapGenerator( """.stripMargin } - /** - * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For - * instance, if we have 2 long group-by keys, the generated function would be of the form: - * - * {{{ - * private long hash(long agg_key, long agg_key1) { - * return agg_key ^ agg_key1; - * } - * }}} - */ - private def generateHashFunction(): String = { - val hash = ctx.freshName("hash") - - def genHashForKeys(groupingKeys: Seq[Buffer]): String = { - groupingKeys.map { key => - val result = ctx.freshName("result") - s""" - |${genComputeHash(ctx, key.name, key.dataType, result)} - |$hash = ($hash ^ (0x9e3779b9)) + $result + ($hash << 6) + ($hash >>> 2); - """.stripMargin - }.mkString("\n") - } - - s""" - |private long hash($groupingKeySignature) { - | long $hash = 0; - | ${genHashForKeys(groupingKeys)} - | return $hash; - |} - """.stripMargin - } /** * Generates a method that returns true if the group-by keys exist at a given index in the @@ -184,7 +115,7 @@ class VectorizedHashMapGenerator( * } * }}} */ - private def generateEquals(): String = { + protected def generateEquals(): String = { def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => @@ -233,7 +164,7 @@ class VectorizedHashMapGenerator( * } * }}} */ - private def generateFindOrInsert(): String = { + protected def generateFindOrInsert(): String = { def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => @@ -287,7 +218,7 @@ class VectorizedHashMapGenerator( """.stripMargin } - private def generateRowIterator(): String = { + protected def generateRowIterator(): String = { s""" |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row> | rowIterator() { @@ -295,52 +226,4 @@ class VectorizedHashMapGenerator( |} """.stripMargin } - - private def generateClose(): String = { - s""" - |public void close() { - | batch.close(); - |} - """.stripMargin - } - - private def genComputeHash( - ctx: CodegenContext, - input: String, - dataType: DataType, - result: String): String = { - def hashInt(i: String): String = s"int $result = $i;" - def hashLong(l: String): String = s"long $result = $l;" - def hashBytes(b: String): String = { - val hash = ctx.freshName("hash") - val bytes = ctx.freshName("bytes") - s""" - |int $result = 0; - |byte[] $bytes = $b; - |for (int i = 0; i < $bytes.length; i++) { - | ${genComputeHash(ctx, s"$bytes[i]", ByteType, hash)} - | $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2); - |} - """.stripMargin - } - - dataType match { - case BooleanType => hashInt(s"$input ? 1 : 0") - case ByteType | ShortType | IntegerType | DateType => hashInt(input) - case LongType | TimestampType => hashLong(input) - case FloatType => hashInt(s"Float.floatToIntBits($input)") - case DoubleType => hashLong(s"Double.doubleToLongBits($input)") - case d: DecimalType => - if (d.precision <= Decimal.MAX_LONG_DIGITS) { - hashLong(s"$input.toUnscaledLong()") - } else { - val bytes = ctx.freshName("bytes") - s""" - final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); - ${hashBytes(bytes)} - """ - } - case StringType => hashBytes(s"$input.getBytes()") - } - } } |