-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java     174
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java                182
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java  185
-rw-r--r--  sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java           425
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala                        176
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala                205
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala              135
7 files changed, 1356 insertions, 126 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
new file mode 100644
index 0000000000..b6130d1f33
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * An implementation of `RowBasedKeyValueBatch` in which all key-value records have the same length.
+ *
+ * The format for each record looks like this:
+ * [UnsafeRow for key of length klen] [UnsafeRow for value of length vlen]
+ * [8 bytes pointer to next]
+ * Thus, record length = klen + vlen + 8
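+ *
+ * For example, assuming a key and a value that each hold two long columns, klen = vlen =
+ * 8 (null bit set) + 16 = 24 bytes, so each record occupies 24 + 24 + 8 = 56 bytes.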
+ */
+public final class FixedLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
+ private final int klen;
+ private final int vlen;
+ private final int recordLength;
+
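+ // all records have the same length, so a key's offset is a simple linear function of its rowId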
+ private final long getKeyOffsetForFixedLengthRecords(int rowId) {
+ return recordStartOffset + rowId * (long) recordLength;
+ }
+
+ /**
+ * Append a key value pair.
+ * It copies data into the backing MemoryBlock.
+ * Returns an UnsafeRow pointing to the value if succeeds, otherwise returns null.
+ */
+ @Override
+ public final UnsafeRow appendRow(Object kbase, long koff, int klen,
+ Object vbase, long voff, int vlen) {
+ // if run out of max supported rows or page size, return null
+ if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
+ return null;
+ }
+
+ long offset = page.getBaseOffset() + pageCursor;
+ final long recordOffset = offset;
+ Platform.copyMemory(kbase, koff, base, offset, klen);
+ offset += klen;
+ Platform.copyMemory(vbase, voff, base, offset, vlen);
+ offset += vlen;
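+ // zero out the trailing 8-byte "pointer to next" slot described in the record format above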
+ Platform.putLong(base, offset, 0);
+
+ pageCursor += recordLength;
+
+ keyRowId = numRows;
+ keyRow.pointTo(base, recordOffset, klen);
+ valueRow.pointTo(base, recordOffset + klen, vlen + 4);
+ numRows++;
+ return valueRow;
+ }
+
+ /**
+ * Returns the key row in this batch at `rowId`. Returned key row is reused across calls.
+ */
+ @Override
+ public final UnsafeRow getKeyRow(int rowId) {
+ assert(rowId >= 0);
+ assert(rowId < numRows);
+ if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
+ long offset = getKeyOffsetForFixedLengthRecords(rowId);
+ keyRow.pointTo(base, offset, klen);
+ // set keyRowId so we can check if desired row is cached
+ keyRowId = rowId;
+ }
+ return keyRow;
+ }
+
+ /**
+ * Returns the value row in two steps:
+ * 1) look up the key row with the same id (skipped if the key row is cached)
+ * 2) retrieve the value row by reusing the metadata from step 1)
+ * Most of the time, 1) is skipped because `getKeyRow(id)` is often called before `getValueRow(id)`.
+ */
+ @Override
+ protected final UnsafeRow getValueFromKey(int rowId) {
+ if (keyRowId != rowId) {
+ getKeyRow(rowId);
+ }
+ assert(rowId >= 0);
+ valueRow.pointTo(base, keyRow.getBaseOffset() + klen, vlen + 4);
+ return valueRow;
+ }
+
+ /**
+ * Returns an iterator to go through all rows
+ */
+ @Override
+ public final org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
+ return new org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>() {
+ private final UnsafeRow key = new UnsafeRow(keySchema.length());
+ private final UnsafeRow value = new UnsafeRow(valueSchema.length());
+
+ private long offsetInPage = 0;
+ private int recordsInPage = 0;
+
+ private boolean initialized = false;
+
+ private void init() {
+ if (page != null) {
+ offsetInPage = page.getBaseOffset();
+ recordsInPage = numRows;
+ }
+ initialized = true;
+ }
+
+ @Override
+ public boolean next() {
+ if (!initialized) init();
+ // no records left to read in this page; free it and finish the iteration
+ if (recordsInPage == 0) {
+ freeCurrentPage();
+ return false;
+ }
+
+ key.pointTo(base, offsetInPage, klen);
+ value.pointTo(base, offsetInPage + klen, vlen + 4);
+
+ offsetInPage += recordLength;
+ recordsInPage -= 1;
+ return true;
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ // do nothing
+ }
+
+ private void freeCurrentPage() {
+ if (page != null) {
+ freePage(page);
+ page = null;
+ }
+ }
+ };
+ }
+
+ protected FixedLengthRowBasedKeyValueBatch(StructType keySchema, StructType valueSchema,
+ int maxRows, TaskMemoryManager manager) {
+ super(keySchema, valueSchema, maxRows, manager);
+ klen = keySchema.defaultSize()
+ + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
+ vlen = valueSchema.defaultSize()
+ + UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
+ recordLength = klen + vlen + 8;
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java
new file mode 100644
index 0000000000..cea9d5d5bc
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.io.IOException;
+
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * RowBasedKeyValueBatch stores key-value pairs in a contiguous memory region.
+ *
+ * Each key or value is stored as a single UnsafeRow. Each record contains one key and one value
+ * and some auxiliary data, which differs based on implementation:
+ * i.e., `FixedLengthRowBasedKeyValueBatch` and `VariableLengthRowBasedKeyValueBatch`.
+ *
+ * We use `FixedLengthRowBasedKeyValueBatch` if all fields in the key and the value are fixed-length
+ * data types. Otherwise we use `VariableLengthRowBasedKeyValueBatch`.
+ *
+ * RowBasedKeyValueBatch is backed by a single page / MemoryBlock (defaults to 64MB). If the page
+ * is full, the aggregate logic should fall back to a second-level, larger hash map. We intentionally
+ * use the single-page design because it simplifies memory address encoding & decoding for each
+ * key-value pair. Because the maximum capacity for RowBasedKeyValueBatch is only 2^16 rows, it is
+ * unlikely we need a second page anyway: filling the 64MB page would require the average key-value
+ * pair to be larger than 1024 bytes (64MB / 2^16).
+ *
+ */
+public abstract class RowBasedKeyValueBatch extends MemoryConsumer {
+ protected final Logger logger = LoggerFactory.getLogger(RowBasedKeyValueBatch.class);
+
+ private static final int DEFAULT_CAPACITY = 1 << 16;
+ private static final long DEFAULT_PAGE_SIZE = 64 * 1024 * 1024;
+
+ protected final StructType keySchema;
+ protected final StructType valueSchema;
+ protected final int capacity;
+ protected int numRows = 0;
+
+ // ids for current key row and value row being retrieved
+ protected int keyRowId = -1;
+
+ // placeholder for key and value corresponding to keyRowId.
+ protected final UnsafeRow keyRow;
+ protected final UnsafeRow valueRow;
+
+ protected MemoryBlock page = null;
+ protected Object base = null;
+ protected final long recordStartOffset;
+ protected long pageCursor = 0;
+
+ public static RowBasedKeyValueBatch allocate(StructType keySchema, StructType valueSchema,
+ TaskMemoryManager manager) {
+ return allocate(keySchema, valueSchema, manager, DEFAULT_CAPACITY);
+ }
+
+ public static RowBasedKeyValueBatch allocate(StructType keySchema, StructType valueSchema,
+ TaskMemoryManager manager, int maxRows) {
+ boolean allFixedLength = true;
+ // check whether there are any variable-length fields
+ // (there is probably a more succinct way to write this)
+ for (String name : keySchema.fieldNames()) {
+ allFixedLength = allFixedLength
+ && UnsafeRow.isFixedLength(keySchema.apply(name).dataType());
+ }
+ for (String name : valueSchema.fieldNames()) {
+ allFixedLength = allFixedLength
+ && UnsafeRow.isFixedLength(valueSchema.apply(name).dataType());
+ }
+
+ if (allFixedLength) {
+ return new FixedLengthRowBasedKeyValueBatch(keySchema, valueSchema, maxRows, manager);
+ } else {
+ return new VariableLengthRowBasedKeyValueBatch(keySchema, valueSchema, maxRows, manager);
+ }
+ }
+
+ protected RowBasedKeyValueBatch(StructType keySchema, StructType valueSchema, int maxRows,
+ TaskMemoryManager manager) {
+ super(manager, manager.pageSizeBytes(), manager.getTungstenMemoryMode());
+
+ this.keySchema = keySchema;
+ this.valueSchema = valueSchema;
+ this.capacity = maxRows;
+
+ this.keyRow = new UnsafeRow(keySchema.length());
+ this.valueRow = new UnsafeRow(valueSchema.length());
+
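+ // try to allocate the single backing page up front; if this fails, the batch stays empty
+ // and appendRow() will keep returning null because page == null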
+ if (!acquirePage(DEFAULT_PAGE_SIZE)) {
+ page = null;
+ recordStartOffset = 0;
+ } else {
+ base = page.getBaseObject();
+ recordStartOffset = page.getBaseOffset();
+ }
+ }
+
+ public final int numRows() { return numRows; }
+
+ public final void close() {
+ if (page != null) {
+ freePage(page);
+ page = null;
+ }
+ }
+
+ private final boolean acquirePage(long requiredSize) {
+ try {
+ page = allocatePage(requiredSize);
+ } catch (OutOfMemoryError e) {
+ logger.warn("Failed to allocate page ({} bytes).", requiredSize);
+ return false;
+ }
+ base = page.getBaseObject();
+ pageCursor = 0;
+ return true;
+ }
+
+ /**
+ * Append a key value pair.
+ * It copies data into the backing MemoryBlock.
+ * Returns an UnsafeRow pointing to the value if succeeds, otherwise returns null.
+ */
+ public abstract UnsafeRow appendRow(Object kbase, long koff, int klen,
+ Object vbase, long voff, int vlen);
+
+ /**
+ * Returns the key row in this batch at `rowId`. Returned key row is reused across calls.
+ */
+ public abstract UnsafeRow getKeyRow(int rowId);
+
+ /**
+ * Returns the value row in this batch at `rowId`. Returned value row is reused across calls.
+ * Because `getValueRow(id)` is always called after `getKeyRow(id)` with the same id, we use
+ * `getValueFromKey(id)` to retrieve the value row, which reuses metadata from the cached key.
+ */
+ public final UnsafeRow getValueRow(int rowId) {
+ return getValueFromKey(rowId);
+ }
+
+ /**
+ * Returns the value row in two steps:
+ * 1) look up the key row with the same id (skipped if the key row is cached)
+ * 2) retrieve the value row by reusing the metadata from step 1)
+ * Most of the time, 1) is skipped because `getKeyRow(id)` is often called before `getValueRow(id)`.
+ */
+ protected abstract UnsafeRow getValueFromKey(int rowId);
+
+ /**
+ * Sometimes the TaskMemoryManager may call spill() on its associated MemoryConsumers to make
+ * space for new consumers. For RowBasedKeyValueBatch, we do not actually spill and simply return 0.
+ * We should not throw an OutOfMemory exception here, because other associated consumers might spill.
+ */
+ public final long spill(long size, MemoryConsumer trigger) throws IOException {
+ logger.warn("Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.");
+ return 0;
+ }
+
+ /**
+ * Returns an iterator to go through all rows
+ */
+ public abstract org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
new file mode 100644
index 0000000000..f4002ee0d5
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
+ *
+ * The format for each record looks like this:
+ * [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
+ * [UnsafeRow for key of length klen] [UnsafeRow for value of length vlen]
+ * [8 bytes pointer to next]
+ * Thus, record length = 4 + 4 + klen + vlen + 8
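+ *
+ * Since klen and vlen differ per record, they are written into the 8-byte header above, and the
+ * start offset of each key row is remembered in `keyOffsets` rather than derived from its rowId.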
+ */
+public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
+ // full addresses for key rows and value rows
+ private final long[] keyOffsets;
+
+ /**
+ * Append a key value pair.
+ * It copies data into the backing MemoryBlock.
+ * Returns an UnsafeRow pointing to the value if succeeds, otherwise returns null.
+ */
+ @Override
+ public final UnsafeRow appendRow(Object kbase, long koff, int klen,
+ Object vbase, long voff, int vlen) {
+ final long recordLength = 8 + klen + vlen + 8;
+ // if run out of max supported rows or page size, return null
+ if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
+ return null;
+ }
+
+ long offset = page.getBaseOffset() + pageCursor;
+ final long recordOffset = offset;
+ Platform.putInt(base, offset, klen + vlen + 4);
+ Platform.putInt(base, offset + 4, klen);
+
+ offset += 8;
+ Platform.copyMemory(kbase, koff, base, offset, klen);
+ offset += klen;
+ Platform.copyMemory(vbase, voff, base, offset, vlen);
+ offset += vlen;
+ Platform.putLong(base, offset, 0);
+
+ pageCursor += recordLength;
+
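+ // remember where this key row starts (just after the 8-byte header); unlike the fixed-length
+ // case, variable-length records cannot be located later by rowId arithmetic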
+ keyOffsets[numRows] = recordOffset + 8;
+
+ keyRowId = numRows;
+ keyRow.pointTo(base, recordOffset + 8, klen);
+ valueRow.pointTo(base, recordOffset + 8 + klen, vlen + 4);
+ numRows++;
+ return valueRow;
+ }
+
+ /**
+ * Returns the key row in this batch at `rowId`. Returned key row is reused across calls.
+ */
+ @Override
+ public UnsafeRow getKeyRow(int rowId) {
+ assert(rowId >= 0);
+ assert(rowId < numRows);
+ if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
+ long offset = keyOffsets[rowId];
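+ // the 4 bytes immediately before the key row hold klen (the second int of the record header)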
+ int klen = Platform.getInt(base, offset - 4);
+ keyRow.pointTo(base, offset, klen);
+ // set keyRowId so we can check if desired row is cached
+ keyRowId = rowId;
+ }
+ return keyRow;
+ }
+
+ /**
+ * Returns the value row in two steps:
+ * 1) look up the key row with the same id (skipped if the key row is cached)
+ * 2) retrieve the value row by reusing the metadata from step 1)
+ * Most of the time, 1) is skipped because `getKeyRow(id)` is often called before `getValueRow(id)`.
+ */
+ @Override
+ public final UnsafeRow getValueFromKey(int rowId) {
+ if (keyRowId != rowId) {
+ getKeyRow(rowId);
+ }
+ assert(rowId >= 0);
+ long offset = keyRow.getBaseOffset();
+ int klen = keyRow.getSizeInBytes();
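+ // the total-size field (klen + vlen + 4) sits 8 bytes before the key row, so vlen = total - klen - 4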
+ int vlen = Platform.getInt(base, offset - 8) - klen - 4;
+ valueRow.pointTo(base, offset + klen, vlen + 4);
+ return valueRow;
+ }
+
+ /**
+ * Returns an iterator to go through all rows
+ */
+ @Override
+ public final org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
+ return new org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>() {
+ private final UnsafeRow key = new UnsafeRow(keySchema.length());
+ private final UnsafeRow value = new UnsafeRow(valueSchema.length());
+
+ private long offsetInPage = 0;
+ private int recordsInPage = 0;
+
+ private int currentklen;
+ private int currentvlen;
+ private int totalLength;
+
+ private boolean initialized = false;
+
+ private void init() {
+ if (page != null) {
+ offsetInPage = page.getBaseOffset();
+ recordsInPage = numRows;
+ }
+ initialized = true;
+ }
+
+ @Override
+ public boolean next() {
+ if (!initialized) init();
+ // no records left to read in this page; free it and finish the iteration
+ if (recordsInPage == 0) {
+ freeCurrentPage();
+ return false;
+ }
+
+ totalLength = Platform.getInt(base, offsetInPage) - 4;
+ currentklen = Platform.getInt(base, offsetInPage + 4);
+ currentvlen = totalLength - currentklen;
+
+ key.pointTo(base, offsetInPage + 8, currentklen);
+ value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen + 4);
+
+ offsetInPage += 8 + totalLength + 8;
+ recordsInPage -= 1;
+ return true;
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ // do nothing
+ }
+
+ private void freeCurrentPage() {
+ if (page != null) {
+ freePage(page);
+ page = null;
+ }
+ }
+ };
+ }
+
+ protected VariableLengthRowBasedKeyValueBatch(StructType keySchema, StructType valueSchema,
+ int maxRows, TaskMemoryManager manager) {
+ super(keySchema, valueSchema, maxRows, manager);
+ this.keyOffsets = new long[maxRows];
+ }
+}
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
new file mode 100644
index 0000000000..0dd129cea7
--- /dev/null
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.util.Random;
+
+public class RowBasedKeyValueBatchSuite {
+
+ private final Random rand = new Random(42);
+
+ private TestMemoryManager memoryManager;
+ private TaskMemoryManager taskMemoryManager;
+ private StructType keySchema = new StructType().add("k1", DataTypes.LongType)
+ .add("k2", DataTypes.StringType);
+ private StructType fixedKeySchema = new StructType().add("k1", DataTypes.LongType)
+ .add("k2", DataTypes.LongType);
+ private StructType valueSchema = new StructType().add("count", DataTypes.LongType)
+ .add("sum", DataTypes.LongType);
+ private int DEFAULT_CAPACITY = 1 << 16;
+
+ private String getRandomString(int length) {
+ Assert.assertTrue(length >= 0);
+ final byte[] bytes = new byte[length];
+ rand.nextBytes(bytes);
+ return new String(bytes);
+ }
+
+ private UnsafeRow makeKeyRow(long k1, String k2) {
+ UnsafeRow row = new UnsafeRow(2);
+ BufferHolder holder = new BufferHolder(row, 32);
+ UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
+ holder.reset();
+ writer.write(0, k1);
+ writer.write(1, UTF8String.fromString(k2));
+ row.setTotalSize(holder.totalSize());
+ return row;
+ }
+
+ private UnsafeRow makeKeyRow(long k1, long k2) {
+ UnsafeRow row = new UnsafeRow(2);
+ BufferHolder holder = new BufferHolder(row, 0);
+ UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
+ holder.reset();
+ writer.write(0, k1);
+ writer.write(1, k2);
+ row.setTotalSize(holder.totalSize());
+ return row;
+ }
+
+ private UnsafeRow makeValueRow(long v1, long v2) {
+ UnsafeRow row = new UnsafeRow(2);
+ BufferHolder holder = new BufferHolder(row, 0);
+ UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
+ holder.reset();
+ writer.write(0, v1);
+ writer.write(1, v2);
+ row.setTotalSize(holder.totalSize());
+ return row;
+ }
+
+ private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
+ return batch.appendRow(key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes());
+ }
+
+ private void updateValueRow(UnsafeRow row, long v1, long v2) {
+ row.setLong(0, v1);
+ row.setLong(1, v2);
+ }
+
+ private boolean checkKey(UnsafeRow row, long k1, String k2) {
+ return (row.getLong(0) == k1)
+ && (row.getUTF8String(1).equals(UTF8String.fromString(k2)));
+ }
+
+ private boolean checkKey(UnsafeRow row, long k1, long k2) {
+ return (row.getLong(0) == k1)
+ && (row.getLong(1) == k2);
+ }
+
+ private boolean checkValue(UnsafeRow row, long v1, long v2) {
+ return (row.getLong(0) == v1) && (row.getLong(1) == v2);
+ }
+
+ @Before
+ public void setup() {
+ memoryManager = new TestMemoryManager(new SparkConf()
+ .set("spark.memory.offHeap.enabled", "false")
+ .set("spark.shuffle.spill.compress", "false")
+ .set("spark.shuffle.compress", "false"));
+ taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+ }
+
+ @After
+ public void tearDown() {
+ if (taskMemoryManager != null) {
+ Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
+ long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
+ taskMemoryManager = null;
+ Assert.assertEquals(0L, leakedMemory);
+ }
+ }
+
+
+ @Test
+ public void emptyBatch() throws Exception {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ Assert.assertEquals(0, batch.numRows());
+ try {
+ batch.getKeyRow(-1);
+ Assert.fail("Should not be able to get row -1");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ try {
+ batch.getValueRow(-1);
+ Assert.fail("Should not be able to get row -1");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ try {
+ batch.getKeyRow(0);
+ Assert.fail("Should not be able to get row 0 when batch is empty");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ try {
+ batch.getValueRow(0);
+ Assert.fail("Should not be able to get row 0 when batch is empty");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ Assert.assertFalse(batch.rowIterator().next());
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void batchType() throws Exception {
+ RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ Assert.assertEquals(batch1.getClass(), VariableLengthRowBasedKeyValueBatch.class);
+ Assert.assertEquals(batch2.getClass(), FixedLengthRowBasedKeyValueBatch.class);
+ } finally {
+ batch1.close();
+ batch2.close();
+ }
+ }
+
+ @Test
+ public void setAndRetrieve() {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
+ Assert.assertTrue(checkValue(ret1, 1, 1));
+ UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
+ Assert.assertTrue(checkValue(ret2, 2, 2));
+ UnsafeRow ret3 = appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
+ Assert.assertTrue(checkValue(ret3, 3, 3));
+ Assert.assertEquals(3, batch.numRows());
+ UnsafeRow retrievedKey1 = batch.getKeyRow(0);
+ Assert.assertTrue(checkKey(retrievedKey1, 1, "A"));
+ UnsafeRow retrievedKey2 = batch.getKeyRow(1);
+ Assert.assertTrue(checkKey(retrievedKey2, 2, "B"));
+ UnsafeRow retrievedValue1 = batch.getValueRow(1);
+ Assert.assertTrue(checkValue(retrievedValue1, 2, 2));
+ UnsafeRow retrievedValue2 = batch.getValueRow(2);
+ Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
+ try {
+ batch.getKeyRow(3);
+ Assert.fail("Should not be able to get row 3");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ try {
+ batch.getValueRow(3);
+ Assert.fail("Should not be able to get row 3");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void setUpdateAndRetrieve() {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
+ Assert.assertEquals(1, batch.numRows());
+ UnsafeRow retrievedValue = batch.getValueRow(0);
+ updateValueRow(retrievedValue, 2, 2);
+ UnsafeRow retrievedValue2 = batch.getValueRow(0);
+ Assert.assertTrue(checkValue(retrievedValue2, 2, 2));
+ } finally {
+ batch.close();
+ }
+ }
+
+
+ @Test
+ public void iteratorTest() throws Exception {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
+ appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
+ appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
+ Assert.assertEquals(3, batch.numRows());
+ org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
+ = batch.rowIterator();
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key1 = iterator.getKey();
+ UnsafeRow value1 = iterator.getValue();
+ Assert.assertTrue(checkKey(key1, 1, "A"));
+ Assert.assertTrue(checkValue(value1, 1, 1));
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key2 = iterator.getKey();
+ UnsafeRow value2 = iterator.getValue();
+ Assert.assertTrue(checkKey(key2, 2, "B"));
+ Assert.assertTrue(checkValue(value2, 2, 2));
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key3 = iterator.getKey();
+ UnsafeRow value3 = iterator.getValue();
+ Assert.assertTrue(checkKey(key3, 3, "C"));
+ Assert.assertTrue(checkValue(value3, 3, 3));
+ Assert.assertFalse(iterator.next());
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void fixedLengthTest() throws Exception {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1));
+ appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2));
+ appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3));
+ UnsafeRow retrievedKey1 = batch.getKeyRow(0);
+ Assert.assertTrue(checkKey(retrievedKey1, 11, 11));
+ UnsafeRow retrievedKey2 = batch.getKeyRow(1);
+ Assert.assertTrue(checkKey(retrievedKey2, 22, 22));
+ UnsafeRow retrievedValue1 = batch.getValueRow(1);
+ Assert.assertTrue(checkValue(retrievedValue1, 2, 2));
+ UnsafeRow retrievedValue2 = batch.getValueRow(2);
+ Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
+ Assert.assertEquals(3, batch.numRows());
+ org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
+ = batch.rowIterator();
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key1 = iterator.getKey();
+ UnsafeRow value1 = iterator.getValue();
+ Assert.assertTrue(checkKey(key1, 11, 11));
+ Assert.assertTrue(checkValue(value1, 1, 1));
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key2 = iterator.getKey();
+ UnsafeRow value2 = iterator.getValue();
+ Assert.assertTrue(checkKey(key2, 22, 22));
+ Assert.assertTrue(checkValue(value2, 2, 2));
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key3 = iterator.getKey();
+ UnsafeRow value3 = iterator.getValue();
+ Assert.assertTrue(checkKey(key3, 33, 33));
+ Assert.assertTrue(checkValue(value3, 3, 3));
+ Assert.assertFalse(iterator.next());
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void appendRowUntilExceedingCapacity() throws Exception {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, 10);
+ try {
+ UnsafeRow key = makeKeyRow(1, "A");
+ UnsafeRow value = makeValueRow(1, 1);
+ for (int i = 0; i < 10; i++) {
+ appendRow(batch, key, value);
+ }
+ UnsafeRow ret = appendRow(batch, key, value);
+ Assert.assertEquals(batch.numRows(), 10);
+ Assert.assertNull(ret);
+ org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
+ = batch.rowIterator();
+ for (int i = 0; i < 10; i++) {
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key1 = iterator.getKey();
+ UnsafeRow value1 = iterator.getValue();
+ Assert.assertTrue(checkKey(key1, 1, "A"));
+ Assert.assertTrue(checkValue(value1, 1, 1));
+ }
+ Assert.assertFalse(iterator.next());
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void appendRowUntilExceedingPageSize() throws Exception {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, 64 * 1024 * 1024); //enough capacity
+ try {
+ UnsafeRow key = makeKeyRow(1, "A");
+ UnsafeRow value = makeValueRow(1, 1);
+ int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8;
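+ // 8 bytes of header (total size + key size) plus an 8-byte "next" pointer per record,
+ // matching the variable-length layout used for this string-keyed schema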
+ int totalSize = 4;
+ int numRows = 0;
+ while (totalSize + recordLength < 64 * 1024 * 1024) { // default page size
+ appendRow(batch, key, value);
+ totalSize += recordLength;
+ numRows++;
+ }
+ UnsafeRow ret = appendRow(batch, key, value);
+ Assert.assertEquals(batch.numRows(), numRows);
+ Assert.assertNull(ret);
+ org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
+ = batch.rowIterator();
+ for (int i = 0; i < numRows; i++) {
+ Assert.assertTrue(iterator.next());
+ UnsafeRow key1 = iterator.getKey();
+ UnsafeRow value1 = iterator.getValue();
+ Assert.assertTrue(checkKey(key1, 1, "A"));
+ Assert.assertTrue(checkValue(value1, 1, 1));
+ }
+ Assert.assertFalse(iterator.next());
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void failureToAllocateFirstPage() throws Exception {
+ memoryManager.limit(1024);
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ try {
+ UnsafeRow key = makeKeyRow(1, "A");
+ UnsafeRow value = makeValueRow(11, 11);
+ UnsafeRow ret = appendRow(batch, key, value);
+ Assert.assertNull(ret);
+ Assert.assertFalse(batch.rowIterator().next());
+ } finally {
+ batch.close();
+ }
+ }
+
+ @Test
+ public void randomizedTest() {
+ RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+ valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+ int numEntry = 100;
+ long[] expectedK1 = new long[numEntry];
+ String[] expectedK2 = new String[numEntry];
+ long[] expectedV1 = new long[numEntry];
+ long[] expectedV2 = new long[numEntry];
+
+ for (int i = 0; i < numEntry; i++) {
+ long k1 = rand.nextLong();
+ String k2 = getRandomString(rand.nextInt(256));
+ long v1 = rand.nextLong();
+ long v2 = rand.nextLong();
+ appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2));
+ expectedK1[i] = k1;
+ expectedK2[i] = k2;
+ expectedV1[i] = v1;
+ expectedV2[i] = v2;
+ }
+ try {
+ for (int j = 0; j < 10000; j++) {
+ int rowId = rand.nextInt(numEntry);
+ if (rand.nextBoolean()) {
+ UnsafeRow key = batch.getKeyRow(rowId);
+ Assert.assertTrue(checkKey(key, expectedK1[rowId], expectedK2[rowId]));
+ }
+ if (rand.nextBoolean()) {
+ UnsafeRow value = batch.getValueRow(rowId);
+ Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId]));
+ }
+ }
+ } finally {
+ batch.close();
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
new file mode 100644
index 0000000000..90deb20e97
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+
+/**
+ * This is a helper class to generate an append-only hash map that can act as a 'cache'
+ * for extremely fast key-value lookups while evaluating aggregates (and fall back to the
+ * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in HashAggregate to speed
+ * up aggregates with keys.
+ *
+ * NOTE: the generated hash map currently doesn't support nullable keys and falls back to the
+ * `BytesToBytesMap` to store them.
+ */
+abstract class HashMapGenerator(
+ ctx: CodegenContext,
+ aggregateExpressions: Seq[AggregateExpression],
+ generatedClassName: String,
+ groupingKeySchema: StructType,
+ bufferSchema: StructType) {
+ case class Buffer(dataType: DataType, name: String)
+
+ val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
+ val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
+ val groupingKeySignature =
+ groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
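+ // one ExprCode per aggregation buffer slot, carrying the codegen'd initial value for that slot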
+ val buffVars: Seq[ExprCode] = {
+ val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ val initExpr = functions.flatMap(f => f.initialValues)
+ initExpr.map { e =>
+ val isNull = ctx.freshName("bufIsNull")
+ val value = ctx.freshName("bufValue")
+ ctx.addMutableState("boolean", isNull, "")
+ ctx.addMutableState(ctx.javaType(e.dataType), value, "")
+ val ev = e.genCode(ctx)
+ val initVars =
+ s"""
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ """.stripMargin
+ ExprCode(ev.code + initVars, isNull, value)
+ }
+ }
+
+ def generate(): String = {
+ s"""
+ |public class $generatedClassName {
+ |${initializeAggregateHashMap()}
+ |
+ |${generateFindOrInsert()}
+ |
+ |${generateEquals()}
+ |
+ |${generateHashFunction()}
+ |
+ |${generateRowIterator()}
+ |
+ |${generateClose()}
+ |}
+ """.stripMargin
+ }
+
+ protected def initializeAggregateHashMap(): String
+
+ /**
+ * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For
+ * instance, if we have 2 long group-by keys, the generated function would be of the form:
+ *
+ * {{{
+ * private long hash(long agg_key, long agg_key1) {
+ * return agg_key ^ agg_key1;
+ * }
+ * }}}
+ */
+ protected final def generateHashFunction(): String = {
+ val hash = ctx.freshName("hash")
+
+ def genHashForKeys(groupingKeys: Seq[Buffer]): String = {
+ groupingKeys.map { key =>
+ val result = ctx.freshName("result")
+ s"""
+ |${genComputeHash(ctx, key.name, key.dataType, result)}
+ |$hash = ($hash ^ (0x9e3779b9)) + $result + ($hash << 6) + ($hash >>> 2);
+ """.stripMargin
+ }.mkString("\n")
+ }
+
+ s"""
+ |private long hash($groupingKeySignature) {
+ | long $hash = 0;
+ | ${genHashForKeys(groupingKeys)}
+ | return $hash;
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns true if the group-by keys exist at a given index.
+ */
+ protected def generateEquals(): String
+
+ /**
+ * Generates a method that returns a row which keeps track of the
+ * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
+ * generated method adds the corresponding row to the associated key-value batch.
+ */
+ protected def generateFindOrInsert(): String
+
+ protected def generateRowIterator(): String
+
+ protected final def generateClose(): String = {
+ s"""
+ |public void close() {
+ | batch.close();
+ |}
+ """.stripMargin
+ }
+
+ protected final def genComputeHash(
+ ctx: CodegenContext,
+ input: String,
+ dataType: DataType,
+ result: String): String = {
+ def hashInt(i: String): String = s"int $result = $i;"
+ def hashLong(l: String): String = s"long $result = $l;"
+ def hashBytes(b: String): String = {
+ val hash = ctx.freshName("hash")
+ val bytes = ctx.freshName("bytes")
+ s"""
+ |int $result = 0;
+ |byte[] $bytes = $b;
+ |for (int i = 0; i < $bytes.length; i++) {
+ | ${genComputeHash(ctx, s"$bytes[i]", ByteType, hash)}
+ | $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2);
+ |}
+ """.stripMargin
+ }
+
+ dataType match {
+ case BooleanType => hashInt(s"$input ? 1 : 0")
+ case ByteType | ShortType | IntegerType | DateType => hashInt(input)
+ case LongType | TimestampType => hashLong(input)
+ case FloatType => hashInt(s"Float.floatToIntBits($input)")
+ case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
+ case d: DecimalType =>
+ if (d.precision <= Decimal.MAX_LONG_DIGITS) {
+ hashLong(s"$input.toUnscaledLong()")
+ } else {
+ val bytes = ctx.freshName("bytes")
+ s"""
+ final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+ ${hashBytes(bytes)}
+ """
+ }
+ case StringType => hashBytes(s"$input.getBytes()")
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
new file mode 100644
index 0000000000..1dea33037c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext}
+import org.apache.spark.sql.types._
+
+/**
+ * This is a helper class to generate an append-only row-based hash map that can act as a 'cache'
+ * for extremely fast key-value lookups while evaluating aggregates (and fall back to the
+ * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in HashAggregate to speed
+ * up aggregates with keys.
+ *
+ * We also have VectorizedHashMapGenerator, which generates an append-only vectorized hash map.
+ * We choose one of the two as the first-level, fast hash map during aggregation.
+ *
+ * NOTE: This row-based hash map currently doesn't support nullable keys and falls back to the
+ * `BytesToBytesMap` to store them.
+ */
+class RowBasedHashMapGenerator(
+ ctx: CodegenContext,
+ aggregateExpressions: Seq[AggregateExpression],
+ generatedClassName: String,
+ groupingKeySchema: StructType,
+ bufferSchema: StructType)
+ extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
+ groupingKeySchema, bufferSchema) {
+
+ protected def initializeAggregateHashMap(): String = {
+ val generatedKeySchema: String =
+ s"new org.apache.spark.sql.types.StructType()" +
+ groupingKeySchema.map { key =>
+ key.dataType match {
+ case d: DecimalType =>
+ s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ |${d.precision}, ${d.scale}))""".stripMargin
+ case _ =>
+ s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ }
+ }.mkString("\n").concat(";")
+
+ val generatedValueSchema: String =
+ s"new org.apache.spark.sql.types.StructType()" +
+ bufferSchema.map { key =>
+ key.dataType match {
+ case d: DecimalType =>
+ s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ |${d.precision}, ${d.scale}))""".stripMargin
+ case _ =>
+ s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ }
+ }.mkString("\n").concat(";")
+
+ s"""
+ | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
+ | private int[] buckets;
+ | private int capacity = 1 << 16;
+ | private double loadFactor = 0.5;
+ | private int numBuckets = (int) (capacity / loadFactor);
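+ |   // with loadFactor 0.5 this is 1 << 17, a power of two, so findOrInsert() below can use
+ |   // (hash & (numBuckets - 1)) as the bucket index mask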
+ | private int maxSteps = 2;
+ | private int numRows = 0;
+ | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema
+ | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema
+ | private Object emptyVBase;
+ | private long emptyVOff;
+ | private int emptyVLen;
+ | private boolean isBatchFull = false;
+ |
+ |
+ | public $generatedClassName(
+ | org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
+ | InternalRow emptyAggregationBuffer) {
+ | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
+ | .allocate(keySchema, valueSchema, taskMemoryManager, capacity);
+ |
+ | final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
+ | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+ |
+ | emptyVBase = emptyBuffer;
+ | emptyVOff = Platform.BYTE_ARRAY_OFFSET;
+ | emptyVLen = emptyBuffer.length;
+ |
+ | buckets = new int[numBuckets];
+ | java.util.Arrays.fill(buckets, -1);
+ | }
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns true if the group-by keys exist at a given index in the
+ * associated [[org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch]].
+ *
+ */
+ protected def generateEquals(): String = {
+
+ def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
+ groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+ s"""(${ctx.genEqual(key.dataType, ctx.getValue("row",
+ key.dataType, ordinal.toString()), key.name)})"""
+ }.mkString(" && ")
+ }
+
+ s"""
+ |private boolean equals(int idx, $groupingKeySignature) {
+ | UnsafeRow row = batch.getKeyRow(buckets[idx]);
+ | return ${genEqualsForKeys(groupingKeys)};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns a
+ * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] which keeps track of the
+ * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
+ * generated method adds the corresponding row to the associated
+ * [[org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch]].
+ *
+ */
+ protected def generateFindOrInsert(): String = {
+ val numVarLenFields = groupingKeys.map(_.dataType).count {
+ case dt if UnsafeRow.isFixedLength(dt) => false
+ // TODO: consider large decimal and interval type
+ case _ => true
+ }
+
+ val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+ s"agg_rowWriter.write(${ordinal}, ${key.name})"}
+ .mkString(";\n")
+
+ s"""
+ |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${
+ groupingKeySignature}) {
+ | long h = hash(${groupingKeys.map(_.name).mkString(", ")});
+ | int step = 0;
+ | int idx = (int) h & (numBuckets - 1);
+ | while (step < maxSteps) {
+ | // Return bucket index if it's either an empty slot or already contains the key
+ | if (buckets[idx] == -1) {
+ | if (numRows < capacity && !isBatchFull) {
+ | // creating the unsafe for new entry
+ | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length});
+ | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
+ | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result,
+ | ${numVarLenFields * 32});
+ | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
+ | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
+ | agg_holder,
+ | ${groupingKeySchema.length});
+ | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
+ | agg_rowWriter.zeroOutNullBytes();
+ | ${createUnsafeRowForKey};
+ | agg_result.setTotalSize(agg_holder.totalSize());
+ | Object kbase = agg_result.getBaseObject();
+ | long koff = agg_result.getBaseOffset();
+ | int klen = agg_result.getSizeInBytes();
+ |
+ | UnsafeRow vRow
+ | = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen);
+ | if (vRow == null) {
+ | isBatchFull = true;
+ | } else {
+ | buckets[idx] = numRows++;
+ | }
+ | return vRow;
+ | } else {
+ | // No more space
+ | return null;
+ | }
+ | } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
+ | return batch.getValueRow(buckets[idx]);
+ | }
+ | idx = (idx + 1) & (numBuckets - 1);
+ | step++;
+ | }
+ | // Didn't find it
+ | return null;
+ |}
+ """.stripMargin
+ }
+
+ protected def generateRowIterator(): String = {
+ s"""
+ |public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
+ | return batch.rowIterator();
+ |}
+ """.stripMargin
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index b4a9059299..7418df90b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,8 +17,8 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext}
import org.apache.spark.sql.types._
/**
@@ -44,49 +44,11 @@ class VectorizedHashMapGenerator(
aggregateExpressions: Seq[AggregateExpression],
generatedClassName: String,
groupingKeySchema: StructType,
- bufferSchema: StructType) {
- case class Buffer(dataType: DataType, name: String)
- val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
- val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
- val groupingKeySignature =
- groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
- val buffVars: Seq[ExprCode] = {
- val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- val initExpr = functions.flatMap(f => f.initialValues)
- initExpr.map { e =>
- val isNull = ctx.freshName("bufIsNull")
- val value = ctx.freshName("bufValue")
- ctx.addMutableState("boolean", isNull, "")
- ctx.addMutableState(ctx.javaType(e.dataType), value, "")
- val ev = e.genCode(ctx)
- val initVars =
- s"""
- | $isNull = ${ev.isNull};
- | $value = ${ev.value};
- """.stripMargin
- ExprCode(ev.code + initVars, isNull, value)
- }
- }
-
- def generate(): String = {
- s"""
- |public class $generatedClassName {
- |${initializeAggregateHashMap()}
- |
- |${generateFindOrInsert()}
- |
- |${generateEquals()}
- |
- |${generateHashFunction()}
- |
- |${generateRowIterator()}
- |
- |${generateClose()}
- |}
- """.stripMargin
- }
+ bufferSchema: StructType)
+ extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
+ groupingKeySchema, bufferSchema) {
- private def initializeAggregateHashMap(): String = {
+ protected def initializeAggregateHashMap(): String = {
val generatedSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
(groupingKeySchema ++ bufferSchema).map { key =>
@@ -140,37 +102,6 @@ class VectorizedHashMapGenerator(
""".stripMargin
}
- /**
- * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For
- * instance, if we have 2 long group-by keys, the generated function would be of the form:
- *
- * {{{
- * private long hash(long agg_key, long agg_key1) {
- * return agg_key ^ agg_key1;
- * }
- * }}}
- */
- private def generateHashFunction(): String = {
- val hash = ctx.freshName("hash")
-
- def genHashForKeys(groupingKeys: Seq[Buffer]): String = {
- groupingKeys.map { key =>
- val result = ctx.freshName("result")
- s"""
- |${genComputeHash(ctx, key.name, key.dataType, result)}
- |$hash = ($hash ^ (0x9e3779b9)) + $result + ($hash << 6) + ($hash >>> 2);
- """.stripMargin
- }.mkString("\n")
- }
-
- s"""
- |private long hash($groupingKeySignature) {
- | long $hash = 0;
- | ${genHashForKeys(groupingKeys)}
- | return $hash;
- |}
- """.stripMargin
- }
/**
* Generates a method that returns true if the group-by keys exist at a given index in the
@@ -184,7 +115,7 @@ class VectorizedHashMapGenerator(
* }
* }}}
*/
- private def generateEquals(): String = {
+ protected def generateEquals(): String = {
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
@@ -233,7 +164,7 @@ class VectorizedHashMapGenerator(
* }
* }}}
*/
- private def generateFindOrInsert(): String = {
+ protected def generateFindOrInsert(): String = {
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
@@ -287,7 +218,7 @@ class VectorizedHashMapGenerator(
""".stripMargin
}
- private def generateRowIterator(): String = {
+ protected def generateRowIterator(): String = {
s"""
|public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>
| rowIterator() {
@@ -295,52 +226,4 @@ class VectorizedHashMapGenerator(
|}
""".stripMargin
}
-
- private def generateClose(): String = {
- s"""
- |public void close() {
- | batch.close();
- |}
- """.stripMargin
- }
-
- private def genComputeHash(
- ctx: CodegenContext,
- input: String,
- dataType: DataType,
- result: String): String = {
- def hashInt(i: String): String = s"int $result = $i;"
- def hashLong(l: String): String = s"long $result = $l;"
- def hashBytes(b: String): String = {
- val hash = ctx.freshName("hash")
- val bytes = ctx.freshName("bytes")
- s"""
- |int $result = 0;
- |byte[] $bytes = $b;
- |for (int i = 0; i < $bytes.length; i++) {
- | ${genComputeHash(ctx, s"$bytes[i]", ByteType, hash)}
- | $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2);
- |}
- """.stripMargin
- }
-
- dataType match {
- case BooleanType => hashInt(s"$input ? 1 : 0")
- case ByteType | ShortType | IntegerType | DateType => hashInt(input)
- case LongType | TimestampType => hashLong(input)
- case FloatType => hashInt(s"Float.floatToIntBits($input)")
- case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
- case d: DecimalType =>
- if (d.precision <= Decimal.MAX_LONG_DIGITS) {
- hashLong(s"$input.toUnscaledLong()")
- } else {
- val bytes = ctx.freshName("bytes")
- s"""
- final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
- ${hashBytes(bytes)}
- """
- }
- case StringType => hashBytes(s"$input.getBytes()")
- }
- }
}