From 0e5ebac3c1f1ff58f938be59c7c9e604977d269c Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 9 Feb 2016 16:41:21 -0800
Subject: [SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate

This PR improves the lookup of BytesToBytesMap by:

1. Generating code to calculate the hash code of the grouping keys.
2. Not using MemoryLocation: fetching the base object and offset for the key and
   value directly (removing the indirection).

Author: Davies Liu

Closes #11010 from davies/gen_map.
---
 .../apache/spark/unsafe/map/BytesToBytesMap.java   | 108 ++++++++++++---------
 .../unsafe/map/AbstractBytesToBytesMapSuite.java   |  64 ++++++------
 project/MimaExcludes.scala                         |   1 +
 .../spark/sql/catalyst/expressions/misc.scala      |   1 -
 .../execution/UnsafeFixedWidthAggregationMap.java  |  34 ++++---
 .../sql/execution/UnsafeKVExternalSorter.java      |   4 +-
 .../spark/sql/execution/WholeStageCodegen.scala    |   6 +-
 .../execution/aggregate/TungstenAggregate.scala    |  10 +-
 .../spark/sql/execution/joins/HashedRelation.scala |  17 ++--
 .../sql/execution/BenchmarkWholeStageCodegen.scala |  64 +++++++++---
 10 files changed, 182 insertions(+), 127 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 3387f9a417..b55a322a1b 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -38,7 +38,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.MemoryLocation;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
 
@@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
 
-  private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
-
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
   private final TaskMemoryManager taskMemoryManager;
@@ -417,7 +414,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * This function always return the same {@link Location} instance to avoid object allocation.
    */
   public Location lookup(Object keyBase, long keyOffset, int keyLength) {
-    safeLookup(keyBase, keyOffset, keyLength, loc);
+    safeLookup(keyBase, keyOffset, keyLength, loc,
+      Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
     return loc;
   }
+
+  /**
+   * Looks up a key, and return a {@link Location} handle that can be used to test existence
+   * and read/write values.
+   *
+   * This function always return the same {@link Location} instance to avoid object allocation.
+   */
+  public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
+    safeLookup(keyBase, keyOffset, keyLength, loc, hash);
+    return loc;
+  }
 
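The point of the new four-argument overload is that a caller which already knows the
key's hash (for example, generated code) can skip re-hashing. A minimal usage sketch
(not part of the patch; names like probeOrInsert are illustrative):

    import org.apache.spark.unsafe.hash.Murmur3_x86_32;
    import org.apache.spark.unsafe.map.BytesToBytesMap;

    class LookupSketch {
      // Probe the map with a hash computed once by the caller; returns true if
      // the key was newly inserted. Uses the same function and seed (42) that
      // the 3-argument lookup() now applies internally.
      static boolean probeOrInsert(BytesToBytesMap map,
                                   Object keyBase, long keyOffset, int keyLength,
                                   Object valueBase, long valueOffset, int valueLength) {
        int hash = Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42);
        BytesToBytesMap.Location loc = map.lookup(keyBase, keyOffset, keyLength, hash);
        if (!loc.isDefined()) {
          // putNewKey() reuses the hash cached in the Location by lookup(),
          // so the key bytes are hashed exactly once.
          return loc.putNewKey(keyBase, keyOffset, keyLength,
                               valueBase, valueOffset, valueLength);
        }
        return false;
      }
    }
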
@@ -426,14 +435,13 @@ public final class BytesToBytesMap extends MemoryConsumer {
    *
    * This is a thread-safe version of `lookup`, could be used by multiple threads.
    */
-  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
+  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
     assert(longArray != null);
 
     if (enablePerfMetrics) {
       numKeyLookups++;
     }
-    final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
-    int pos = hashcode & mask;
+    int pos = hash & mask;
     int step = 1;
     while (true) {
       if (enablePerfMetrics) {
@@ -441,22 +449,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
       }
       if (longArray.get(pos * 2) == 0) {
         // This is a new key.
-        loc.with(pos, hashcode, false);
+        loc.with(pos, hash, false);
         return;
       } else {
         long stored = longArray.get(pos * 2 + 1);
-        if ((int) (stored) == hashcode) {
+        if ((int) (stored) == hash) {
           // Full hash code matches.  Let's compare the keys for equality.
-          loc.with(pos, hashcode, true);
+          loc.with(pos, hash, true);
           if (loc.getKeyLength() == keyLength) {
-            final MemoryLocation keyAddress = loc.getKeyAddress();
-            final Object storedkeyBase = keyAddress.getBaseObject();
-            final long storedkeyOffset = keyAddress.getBaseOffset();
             final boolean areEqual = ByteArrayMethods.arrayEquals(
               keyBase,
               keyOffset,
-              storedkeyBase,
-              storedkeyOffset,
+              loc.getKeyBase(),
+              loc.getKeyOffset(),
               keyLength
             );
             if (areEqual) {
@@ -484,13 +489,14 @@ public final class BytesToBytesMap extends MemoryConsumer {
     private boolean isDefined;
     /**
      * The hashcode of the most recent key passed to
-     * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
-     * avoid re-hashing the key when storing a value for that key.
+     * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us
+     * to avoid re-hashing the key when storing a value for that key.
      */
     private int keyHashcode;
-    private final MemoryLocation keyMemoryLocation = new MemoryLocation();
-    private final MemoryLocation valueMemoryLocation = new MemoryLocation();
+    private Object baseObject;  // the base object for key and value
+    private long keyOffset;
     private int keyLength;
+    private long valueOffset;
     private int valueLength;
 
     /**
@@ -504,18 +510,15 @@ public final class BytesToBytesMap extends MemoryConsumer {
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(final Object base, final long offset) {
-      long position = offset;
-      final int totalLength = Platform.getInt(base, position);
-      position += 4;
-      keyLength = Platform.getInt(base, position);
-      position += 4;
+    private void updateAddressesAndSizes(final Object base, long offset) {
+      baseObject = base;
+      final int totalLength = Platform.getInt(base, offset);
+      offset += 4;
+      keyLength = Platform.getInt(base, offset);
+      offset += 4;
+      keyOffset = offset;
+      valueOffset = offset + keyLength;
       valueLength = totalLength - keyLength - 4;
-
-      keyMemoryLocation.setObjAndOffset(base, position);
-
-      position += keyLength;
-      valueMemoryLocation.setObjAndOffset(base, position);
     }
 
     private Location with(int pos, int keyHashcode, boolean isDefined) {
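The Location now decodes the record layout in place of the two MemoryLocation
objects: a record is [total length: 4 bytes][key length: 4 bytes][key bytes][value
bytes], and key and value share one base object. A worked decoding sketch (not part
of the patch; the printout is only for illustration):

    import org.apache.spark.unsafe.Platform;

    class RecordLayoutSketch {
      // Decode a record given a base object and the offset of its first byte.
      static void decode(Object base, long offset) {
        final int totalLength = Platform.getInt(base, offset);   // key-length field + key + value
        final int keyLength = Platform.getInt(base, offset + 4);
        final long keyOffset = offset + 8;                       // first key byte
        final long valueOffset = keyOffset + keyLength;          // first value byte
        final int valueLength = totalLength - keyLength - 4;     // 4 = key-length field itself
        System.out.printf("key @%d (%d bytes), value @%d (%d bytes)%n",
            keyOffset, keyLength, valueOffset, valueLength);
      }
    }
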
@@ -543,10 +546,11 @@
     private Location with(Object base, long offset, int length) {
       this.isDefined = true;
       this.memoryPage = null;
+      baseObject = base;
+      keyOffset = offset + 4;
       keyLength = Platform.getInt(base, offset);
+      valueOffset = offset + 4 + keyLength;
       valueLength = length - 4 - keyLength;
-      keyMemoryLocation.setObjAndOffset(base, offset + 4);
-      valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
       return this;
     }
 
@@ -566,34 +570,44 @@ public final class BytesToBytesMap extends MemoryConsumer {
     }
 
     /**
-     * Returns the address of the key defined at this position.
-     * This points to the first byte of the key data.
-     * Unspecified behavior if the key is not defined.
-     * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
+     * Returns the base object for key.
      */
-    public MemoryLocation getKeyAddress() {
+    public Object getKeyBase() {
       assert (isDefined);
-      return keyMemoryLocation;
+      return baseObject;
     }
 
     /**
-     * Returns the length of the key defined at this position.
-     * Unspecified behavior if the key is not defined.
+     * Returns the offset for key.
      */
-    public int getKeyLength() {
+    public long getKeyOffset() {
       assert (isDefined);
-      return keyLength;
+      return keyOffset;
+    }
+
+    /**
+     * Returns the base object for value.
+     */
+    public Object getValueBase() {
+      assert (isDefined);
+      return baseObject;
     }
 
     /**
-     * Returns the address of the value defined at this position.
-     * This points to the first byte of the value data.
+     * Returns the offset for value.
+     */
+    public long getValueOffset() {
+      assert (isDefined);
+      return valueOffset;
+    }
+
+    /**
+     * Returns the length of the key defined at this position.
      * Unspecified behavior if the key is not defined.
-     * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
      */
-    public MemoryLocation getValueAddress() {
+    public int getKeyLength() {
       assert (isDefined);
-      return valueMemoryLocation;
+      return keyLength;
     }
 
     /**
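With the accessors above, call sites read key and value bytes through a base object
plus a raw offset instead of going through a shared MemoryLocation. A minimal read
sketch for fixed-width 8-byte entries (illustrative, not from the patch):

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.map.BytesToBytesMap;

    class ReadSketch {
      // Read an 8-byte key and 8-byte value from a defined Location.
      static long[] readEntry(BytesToBytesMap.Location loc) {
        assert loc.isDefined();
        long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
        long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
        return new long[] { key, value };
      }
    }
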
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 702ba5469b..d8af2b336d 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -39,14 +39,13 @@ import org.mockito.stubbing.Answer;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.MemoryLocation;
 import org.apache.spark.util.Utils;
 
 import static org.hamcrest.Matchers.greaterThan;
@@ -142,10 +141,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   protected abstract boolean useOffHeapMemoryAllocator();
 
-  private static byte[] getByteArray(MemoryLocation loc, int size) {
+  private static byte[] getByteArray(Object base, long offset, int size) {
     final byte[] arr = new byte[size];
-    Platform.copyMemory(
-      loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size);
+    Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
     return arr;
   }
 
@@ -163,13 +161,14 @@ public abstract class AbstractBytesToBytesMapSuite {
    */
   private static boolean arrayEquals(
       byte[] expected,
-      MemoryLocation actualAddr,
+      Object base,
+      long offset,
       long actualLengthBytes) {
     return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
       expected,
       Platform.BYTE_ARRAY_OFFSET,
-      actualAddr.getBaseObject(),
-      actualAddr.getBaseOffset(),
+      base,
+      offset,
       expected.length
     );
   }
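The test helpers above lean on the (base, offset) addressing convention used
throughout Tungsten: an on-heap byte[] is addressed as (array,
Platform.BYTE_ARRAY_OFFSET), off-heap memory as (null, absolute address). A tiny
round-trip sketch of that convention (illustrative only):

    import org.apache.spark.unsafe.Platform;

    class CopySketch {
      // Copy a byte[] through unsafe memory addressing; works identically for
      // any (base, offset) pair, on-heap or off-heap.
      static byte[] copyOf(byte[] src) {
        final byte[] dst = new byte[src.length];
        Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET,
                            dst, Platform.BYTE_ARRAY_OFFSET, src.length);
        return dst;
      }
    }
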
@@ -212,16 +211,20 @@ public abstract class AbstractBytesToBytesMapSuite {
       // reflect the result of this store without us having to call lookup() again on the same key.
       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
       Assert.assertEquals(recordLengthBytes, loc.getValueLength());
-      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
-      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(keyData,
+        getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData,
+        getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
 
       // After calling lookup() the location should still point to the correct data.
       Assert.assertTrue(
         map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
       Assert.assertEquals(recordLengthBytes, loc.getValueLength());
-      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
-      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(keyData,
+        getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData,
+        getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
 
       try {
         Assert.assertTrue(loc.putNewKey(
@@ -283,15 +286,12 @@ public abstract class AbstractBytesToBytesMapSuite {
     while (iter.hasNext()) {
       final BytesToBytesMap.Location loc = iter.next();
       Assert.assertTrue(loc.isDefined());
-      final MemoryLocation keyAddress = loc.getKeyAddress();
-      final MemoryLocation valueAddress = loc.getValueAddress();
-      final long value = Platform.getLong(
-        valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+      final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
       final long keyLength = loc.getKeyLength();
       if (keyLength == 0) {
         Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
       } else {
-        final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+        final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
         Assert.assertEquals(value, key);
       }
       valuesSeen.set((int) value);
@@ -365,15 +365,15 @@ public abstract class AbstractBytesToBytesMapSuite {
       Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
       Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
       Platform.copyMemory(
-        loc.getKeyAddress().getBaseObject(),
-        loc.getKeyAddress().getBaseOffset(),
+        loc.getKeyBase(),
+        loc.getKeyOffset(),
         key,
         Platform.LONG_ARRAY_OFFSET,
         KEY_LENGTH
       );
       Platform.copyMemory(
-        loc.getValueAddress().getBaseObject(),
-        loc.getValueAddress().getBaseOffset(),
+        loc.getValueBase(),
+        loc.getValueOffset(),
         value,
         Platform.LONG_ARRAY_OFFSET,
         VALUE_LENGTH
@@ -425,8 +425,9 @@ public abstract class AbstractBytesToBytesMapSuite {
       Assert.assertTrue(loc.isDefined());
       Assert.assertEquals(key.length, loc.getKeyLength());
       Assert.assertEquals(value.length, loc.getValueLength());
-      Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
-      Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+      Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
+      Assert.assertTrue(
+        arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
     }
   }
 
@@ -436,8 +437,10 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
       Assert.assertTrue(loc.isDefined());
-      Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
-      Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+      Assert.assertTrue(
+        arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
+      Assert.assertTrue(
+        arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
     }
   } finally {
     map.free();
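The iteration pattern exercised by these tests also shows the intended consumer
usage: each Location returned by the iterator is only valid until the next call to
next(), so it should be read before advancing. A minimal sketch under that
assumption (not from the patch):

    import java.util.Iterator;
    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.map.BytesToBytesMap;

    class IterateSketch {
      // Sum 8-byte values across all entries, reading through the new accessors.
      static long sumValues(BytesToBytesMap map) {
        long sum = 0;
        Iterator<BytesToBytesMap.Location> iter = map.iterator();
        while (iter.hasNext()) {
          BytesToBytesMap.Location loc = iter.next();
          sum += Platform.getLong(loc.getValueBase(), loc.getValueOffset());
        }
        return sum;
      }
    }
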
@@ -476,8 +479,9 @@ public abstract class AbstractBytesToBytesMapSuite {
         Assert.assertTrue(loc.isDefined());
         Assert.assertEquals(key.length, loc.getKeyLength());
         Assert.assertEquals(value.length, loc.getValueLength());
-        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
-        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+        Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
+        Assert.assertTrue(
+          arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
       }
     }
     for (Map.Entry entry : expected.entrySet()) {
@@ -486,8 +490,10 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
       Assert.assertTrue(loc.isDefined());
-      Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
-      Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+      Assert.assertTrue(
+        arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
+      Assert.assertTrue(
+        arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
     }
   } finally {
     map.free();
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9209094385..133894704b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,6 +40,7 @@ object MimaExcludes {
         excludePackage("org.apache.spark.rpc"),
         excludePackage("org.spark-project.jetty"),
         excludePackage("org.apache.spark.unused"),
+        excludePackage("org.apache.spark.unsafe"),
         excludePackage("org.apache.spark.util.collection.unsafe"),
         excludePackage("org.apache.spark.sql.catalyst"),
         excludePackage("org.apache.spark.sql.execution"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index f4ccadd9c5..28e4f50eee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
     }
   }
 
-
   override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
     ev.isNull = "false"
     val childrenHash = children.map { child =>
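The Murmur3Hash expression already supports codegen; the aggregate below binds it
to the grouping expressions so the hash is computed inline, next to the key. As a
hypothetical illustration of the shape such a generated fragment takes for one long
grouping column (the real fragment comes from genCode() and differs in detail):

    import org.apache.spark.unsafe.hash.Murmur3_x86_32;

    class GeneratedHashSketch {
      // Roughly what codegen can emit for hashing one long grouping column with
      // seed 42 -- the same seed the map's default lookup() uses.
      static int hashGroupingKey(long groupingValue) {
        return Murmur3_x86_32.hashLong(groupingValue, 42);
      }
    }
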
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 6bf9d7bd03..2e84178d69 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,19 +121,24 @@ public final class UnsafeFixedWidthAggregationMap {
     return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
   }
 
-  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
+  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
+    return getAggregationBufferFromUnsafeRow(key, key.hashCode());
+  }
+
+  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
     // Probe our map using the serialized key
     final BytesToBytesMap.Location loc = map.lookup(
-      unsafeGroupingKeyRow.getBaseObject(),
-      unsafeGroupingKeyRow.getBaseOffset(),
-      unsafeGroupingKeyRow.getSizeInBytes());
+      key.getBaseObject(),
+      key.getBaseOffset(),
+      key.getSizeInBytes(),
+      hash);
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
       boolean putSucceeded = loc.putNewKey(
-        unsafeGroupingKeyRow.getBaseObject(),
-        unsafeGroupingKeyRow.getBaseOffset(),
-        unsafeGroupingKeyRow.getSizeInBytes(),
+        key.getBaseObject(),
+        key.getBaseOffset(),
+        key.getSizeInBytes(),
         emptyAggregationBuffer,
         Platform.BYTE_ARRAY_OFFSET,
         emptyAggregationBuffer.length
@@ -144,10 +149,9 @@ public final class UnsafeFixedWidthAggregationMap {
     }
 
     // Reset the pointer to point to the value that we just stored or looked up:
-    final MemoryLocation address = loc.getValueAddress();
     currentAggregationBuffer.pointTo(
-      address.getBaseObject(),
-      address.getBaseOffset(),
+      loc.getValueBase(),
+      loc.getValueOffset(),
       loc.getValueLength()
     );
     return currentAggregationBuffer;
@@ -172,16 +176,14 @@ public final class UnsafeFixedWidthAggregationMap {
     public boolean next() {
       if (mapLocationIterator.hasNext()) {
         final BytesToBytesMap.Location loc = mapLocationIterator.next();
-        final MemoryLocation keyAddress = loc.getKeyAddress();
-        final MemoryLocation valueAddress = loc.getValueAddress();
         key.pointTo(
-          keyAddress.getBaseObject(),
-          keyAddress.getBaseOffset(),
+          loc.getKeyBase(),
+          loc.getKeyOffset(),
           loc.getKeyLength()
         );
         value.pointTo(
-          valueAddress.getBaseObject(),
-          valueAddress.getBaseOffset(),
+          loc.getValueBase(),
+          loc.getValueOffset(),
           loc.getValueLength()
         );
         return true;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 0da26bf376..51e10b0e93 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -97,8 +97,8 @@ public final class UnsafeKVExternalSorter {
       UnsafeRow row = new UnsafeRow(numKeyFields);
       while (iter.hasNext()) {
         final BytesToBytesMap.Location loc = iter.next();
-        final Object baseObject = loc.getKeyAddress().getBaseObject();
-        final long baseOffset = loc.getKeyAddress().getBaseOffset();
+        final Object baseObject = loc.getKeyBase();
+        final long baseOffset = loc.getKeyOffset();
 
         // Get encoded memory address
         // baseObject + baseOffset point to the beginning of the key data in the map, but that
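Condensed, the aggregation path now probes with a caller-supplied hash and points a
reused UnsafeRow at the stored value bytes without copying. A sketch of that flow,
based on getAggregationBufferFromUnsafeRow(key, hash) above (class plumbing
omitted; names of the helper and its parameters are illustrative):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.map.BytesToBytesMap;

    class AggBufferSketch {
      // Probe-or-insert; returns null when the map cannot allocate space,
      // signalling the caller to fall back to sort-based aggregation.
      static UnsafeRow probe(BytesToBytesMap map, UnsafeRow key, int hash,
                             byte[] emptyBuffer, UnsafeRow reusedBuffer) {
        BytesToBytesMap.Location loc = map.lookup(
            key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(), hash);
        if (!loc.isDefined()) {
          boolean ok = loc.putNewKey(
              key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
              emptyBuffer, Platform.BYTE_ARRAY_OFFSET, emptyBuffer.length);
          if (!ok) {
            return null;  // out of memory: spill and retry
          }
        }
        // Point the reused row at the value bytes stored in the map (no copy).
        reusedBuffer.pointTo(loc.getValueBase(), loc.getValueOffset(), loc.getValueLength());
        return reusedBuffer;
      }
    }
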
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 4ca2d85406..b200239c94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
   def apply(plan: SparkPlan): SparkPlan = {
     if (sqlContext.conf.wholeStageEnabled) {
       plan.transform {
-        case plan: CodegenSupport if supportCodegen(plan) &&
-          // Whole stage codegen is only useful when there are at least two levels of operators that
-          // support it (save at least one projection/iterator).
-          (Utils.isTesting || plan.children.exists(supportCodegen)) =>
-
+        case plan: CodegenSupport if supportCodegen(plan) =>
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
             // The build side can't be compiled together
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 9d9f14f2dd..340b8f78e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -501,6 +501,11 @@ case class TungstenAggregate(
       }
     }
 
+    // generate hash code for key
+    val hashExpr = Murmur3Hash(groupingExpressions, 42)
+    ctx.currentVars = input
+    val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
+
     val inputAttr = bufferAttributes ++ child.output
     ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
     ctx.INPUT_ROW = buffer
@@ -526,10 +531,11 @@ case class TungstenAggregate(
     s"""
      // generate grouping key
      ${keyCode.code.trim}
+     ${hashEval.code.trim}
      UnsafeRow $buffer = null;
      if ($checkFallback) {
        // try to get the buffer from hash map
-       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
      }
      if ($buffer == null) {
        if ($sorterTerm == null) {
@@ -540,7 +546,7 @@ case class TungstenAggregate(
        $resetCoulter
        // the hash map had be spilled, it should have enough memory now,
        // try to allocate buffer again.
-       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
        if ($buffer == null) {
          // failed to allocate the first page
          throw new OutOfMemoryError("No enough memory for aggregation");
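The template above emits a spill-and-retry fallback: if the hash map cannot hand
out an aggregation buffer, the generated code spills the map to a sorter and
retries once before giving up. A condensed sketch of that control flow, under the
assumption of a hypothetical HashMapLike interface (the real code is emitted as the
string template above, not written by hand):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;

    class FallbackSketch {
      interface HashMapLike {
        UnsafeRow getBuffer(UnsafeRow key, int hash);  // null when allocation fails
        void spillToSorter();                          // frees the map's pages
      }

      static UnsafeRow bufferOrSpill(HashMapLike map, UnsafeRow key, int hash) {
        UnsafeRow buffer = map.getBuffer(key, hash);
        if (buffer == null) {
          map.spillToSorter();                 // release memory held by the map
          buffer = map.getBuffer(key, hash);   // should succeed now
          if (buffer == null) {
            throw new OutOfMemoryError("No enough memory for aggregation");
          }
        }
        return buffer;
      }
    }
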
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index c94d6c195b..eb6930a14f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation(
       val map = binaryMap  // avoid the compiler error
       val loc = new map.Location  // this could be allocated in stack
       binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-        unsafeKey.getSizeInBytes, loc)
+        unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
       if (loc.isDefined) {
         val buffer = CompactBuffer[UnsafeRow]()
 
-        val base = loc.getValueAddress.getBaseObject
-        var offset = loc.getValueAddress.getBaseOffset
-        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+        val base = loc.getValueBase
+        var offset = loc.getValueOffset
+        val last = offset + loc.getValueLength
         while (offset < last) {
           val numFields = Platform.getInt(base, offset)
           val sizeInBytes = Platform.getInt(base, offset + 4)
@@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation(
     out.writeInt(binaryMap.numElements())
 
     var buffer = new Array[Byte](64)
-    def write(addr: MemoryLocation, length: Int): Unit = {
+    def write(base: Object, offset: Long, length: Int): Unit = {
       if (buffer.length < length) {
         buffer = new Array[Byte](length)
       }
-      Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
-        buffer, Platform.BYTE_ARRAY_OFFSET, length)
+      Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
       out.write(buffer, 0, length)
     }
 
@@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation(
       // [key size] [values size] [key bytes] [values bytes]
       out.writeInt(loc.getKeyLength)
       out.writeInt(loc.getValueLength)
-      write(loc.getKeyAddress, loc.getKeyLength)
-      write(loc.getValueAddress, loc.getValueLength)
+      write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+      write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
     }
 
   } else {
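Note how UnsafeHashedRelation uses safeLookup with a Location it allocates itself:
lookups then never mutate the map's shared Location, which is what makes concurrent
readers safe. A minimal sketch of the same pattern from Java (illustrative; the
patch itself does this from Scala as `new map.Location`):

    import org.apache.spark.unsafe.map.BytesToBytesMap;

    class ThreadSafeLookupSketch {
      // Each reader thread allocates its own Location; the hash is supplied by
      // the caller, matching the new safeLookup signature.
      static boolean contains(BytesToBytesMap map,
                              Object keyBase, long keyOffset, int keyLength, int hash) {
        BytesToBytesMap.Location loc = map.new Location();
        map.safeLookup(keyBase, keyOffset, keyLength, loc, hash);
        return loc.isDefined();
      }
    }
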
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index f015d29704..dc6c647a4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     }
 
     /*
-    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    Aggregate w keys:           Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-    -------------------------------------------------------------------------------------------
-    Aggregate w keys codegen=false          2402 / 2551          8.0         125.0       1.0X
-    Aggregate w keys codegen=true           1620 / 1670         12.0          83.3       1.5X
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Aggregate w keys:           Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Aggregate w keys codegen=false          2429 / 2644          8.6         115.8       1.0X
+    Aggregate w keys codegen=true           1535 / 1571         13.7          73.2       1.6X
     */
   }
 
@@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     benchmark.addCase("hash") { iter =>
       var i = 0
       val keyBytes = new Array[Byte](16)
-      val valueBytes = new Array[Byte](16)
       val key = new UnsafeRow(1)
       key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
-      val value = new UnsafeRow(2)
-      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
       var s = 0
       while (i < N) {
         key.setInt(0, i % 1000)
         val h = Murmur3_x86_32.hashUnsafeWords(
-          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42)
         s += h
         i += 1
       }
     }
+
+    benchmark.addCase("fast hash") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      var s = 0
+      while (i < N) {
+        key.setInt(0, i % 1000)
+        val h = Murmur3_x86_32.hashLong(i % 1000, 42)
+        s += h
+        i += 1
+      }
+    }
+
+    benchmark.addCase("arrayEqual") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      var s = 0
+      while (i < N) {
+        key.setInt(0, i % 1000)
+        if (key.equals(value)) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
     Seq("off", "on").foreach { heap =>
       benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
         val taskMemoryManager = new TaskMemoryManager(
@@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
         val valueBytes = new Array[Byte](16)
         val key = new UnsafeRow(1)
         key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
-        val value = new UnsafeRow(2)
+        val value = new UnsafeRow(1)
         value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
         var i = 0
         while (i < N) {
           key.setInt(0, i % 65536)
-          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+            Murmur3_x86_32.hashLong(i % 65536, 42))
           if (loc.isDefined) {
-            value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
-              loc.getValueLength)
+            value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
             value.setInt(0, value.getInt(0) + 1)
             i += 1
           } else {
@@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     BytesToBytesMap:            Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    hash                                     628 /  661         83.0          12.0       1.0X
-    BytesToBytesMap (off Heap)              3292 / 3408         15.0          66.7       0.2X
-    BytesToBytesMap (on Heap)               3349 / 4267         15.0          66.7       0.2X
+    hash                                     651 /  678         80.0          12.5       1.0X
+    fast hash                                336 /  343        155.9           6.4       1.9X
+    arrayEqual                               417 /  428        125.0           8.0       1.6X
+    BytesToBytesMap (off Heap)              2594 / 2664         20.2          49.5       0.2X
+    BytesToBytesMap (on Heap)               2693 / 2989         19.5          51.4       0.2X
     */
     benchmark.run()
   }
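For context on the numbers above: "fast hash" (hashLong on the primitive value) is
roughly 1.9X the generic hashUnsafeWords over the serialized row bytes, which is
what motivates generating key-specific hash code. A rough standalone harness to
reproduce the comparison (a sketch only -- timings vary by machine and JIT warm-up,
and a tool like JMH would be the rigorous choice):

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.hash.Murmur3_x86_32;

    class HashBenchSketch {
      public static void main(String[] args) {
        final byte[] keyBytes = new byte[16];  // 16 serialized row bytes, as in the benchmark
        int s = 0;
        long t0 = System.nanoTime();
        for (long i = 0; i < 10_000_000L; i++) {
          s += Murmur3_x86_32.hashUnsafeWords(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16, 42);
        }
        long t1 = System.nanoTime();
        for (long i = 0; i < 10_000_000L; i++) {
          s += Murmur3_x86_32.hashLong(i % 1000, 42);
        }
        long t2 = System.nanoTime();
        System.out.println("hashUnsafeWords: " + (t1 - t0) / 1e6 + " ms, hashLong: "
            + (t2 - t1) / 1e6 + " ms, checksum " + s);
      }
    }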