author     Davies Liu <davies@databricks.com>   2016-02-09 16:41:21 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-02-09 16:41:21 -0800
commit     0e5ebac3c1f1ff58f938be59c7c9e604977d269c (patch)
tree       6a572ee7aa4f79f285a78a083dc69eef87627be0
parent     fae830d15846f7ffdfe49eeb45e175a3cdd2c670 (diff)
[SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate
This PR improves the lookup of BytesToBytesMap by:
1. Generating code to calculate the hash code of the grouping keys.
2. Not using MemoryLocation: fetching the baseObject and offset for key and value directly (removing the indirection).

Author: Davies Liu <davies@databricks.com>

Closes #11010 from davies/gen_map.
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java                      108
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java          64
-rw-r--r--  project/MimaExcludes.scala                                                                  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala            1
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java  34
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java           4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala              6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala   10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala          17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala    64
10 files changed, 182 insertions, 127 deletions
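
For orientation before the diff itself: a minimal sketch (not part of the patch) of the lookup pattern this change enables. The caller computes the Murmur3 hash once, with seed 42 to match UnsafeRow.hashCode(), and hands it to the new lookup() overload; the Location then exposes base objects and offsets directly, with no MemoryLocation in between. Method and variable names here are illustrative.

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.hash.Murmur3_x86_32;
    import org.apache.spark.unsafe.map.BytesToBytesMap;

    // Sketch: probe the map with a caller-computed hash, then read the value
    // through the new direct accessors (no MemoryLocation indirection).
    static long probe(BytesToBytesMap map, Object keyBase, long keyOffset, int keyLength) {
      int hash = Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42);
      BytesToBytesMap.Location loc = map.lookup(keyBase, keyOffset, keyLength, hash);
      if (!loc.isDefined()) {
        return -1L; // key absent; real callers would insert via loc.putNewKey(...)
      }
      return Platform.getLong(loc.getValueBase(), loc.getValueOffset());
    }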
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 3387f9a417..b55a322a1b 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -38,7 +38,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
@@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
- private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
-
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
private final TaskMemoryManager taskMemoryManager;
@@ -417,7 +414,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
* This function always return the same {@link Location} instance to avoid object allocation.
*/
public Location lookup(Object keyBase, long keyOffset, int keyLength) {
- safeLookup(keyBase, keyOffset, keyLength, loc);
+ safeLookup(keyBase, keyOffset, keyLength, loc,
+ Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
+ return loc;
+ }
+
+ /**
+ * Looks up a key, and return a {@link Location} handle that can be used to test existence
+ * and read/write values.
+ *
+ * This function always return the same {@link Location} instance to avoid object allocation.
+ */
+ public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
+ safeLookup(keyBase, keyOffset, keyLength, loc, hash);
return loc;
}
@@ -426,14 +435,13 @@ public final class BytesToBytesMap extends MemoryConsumer {
*
* This is a thread-safe version of `lookup`, could be used by multiple threads.
*/
- public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
+ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
assert(longArray != null);
if (enablePerfMetrics) {
numKeyLookups++;
}
- final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
- int pos = hashcode & mask;
+ int pos = hash & mask;
int step = 1;
while (true) {
if (enablePerfMetrics) {
@@ -441,22 +449,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
}
if (longArray.get(pos * 2) == 0) {
// This is a new key.
- loc.with(pos, hashcode, false);
+ loc.with(pos, hash, false);
return;
} else {
long stored = longArray.get(pos * 2 + 1);
- if ((int) (stored) == hashcode) {
+ if ((int) (stored) == hash) {
// Full hash code matches. Let's compare the keys for equality.
- loc.with(pos, hashcode, true);
+ loc.with(pos, hash, true);
if (loc.getKeyLength() == keyLength) {
- final MemoryLocation keyAddress = loc.getKeyAddress();
- final Object storedkeyBase = keyAddress.getBaseObject();
- final long storedkeyOffset = keyAddress.getBaseOffset();
final boolean areEqual = ByteArrayMethods.arrayEquals(
keyBase,
keyOffset,
- storedkeyBase,
- storedkeyOffset,
+ loc.getKeyBase(),
+ loc.getKeyOffset(),
keyLength
);
if (areEqual) {
@@ -484,13 +489,14 @@ public final class BytesToBytesMap extends MemoryConsumer {
private boolean isDefined;
/**
* The hashcode of the most recent key passed to
- * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
- * avoid re-hashing the key when storing a value for that key.
+ * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us
+ * to avoid re-hashing the key when storing a value for that key.
*/
private int keyHashcode;
- private final MemoryLocation keyMemoryLocation = new MemoryLocation();
- private final MemoryLocation valueMemoryLocation = new MemoryLocation();
+ private Object baseObject; // the base object for key and value
+ private long keyOffset;
private int keyLength;
+ private long valueOffset;
private int valueLength;
/**
@@ -504,18 +510,15 @@ public final class BytesToBytesMap extends MemoryConsumer {
taskMemoryManager.getOffsetInPage(fullKeyAddress));
}
- private void updateAddressesAndSizes(final Object base, final long offset) {
- long position = offset;
- final int totalLength = Platform.getInt(base, position);
- position += 4;
- keyLength = Platform.getInt(base, position);
- position += 4;
+ private void updateAddressesAndSizes(final Object base, long offset) {
+ baseObject = base;
+ final int totalLength = Platform.getInt(base, offset);
+ offset += 4;
+ keyLength = Platform.getInt(base, offset);
+ offset += 4;
+ keyOffset = offset;
+ valueOffset = offset + keyLength;
valueLength = totalLength - keyLength - 4;
-
- keyMemoryLocation.setObjAndOffset(base, position);
-
- position += keyLength;
- valueMemoryLocation.setObjAndOffset(base, position);
}
private Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -543,10 +546,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
private Location with(Object base, long offset, int length) {
this.isDefined = true;
this.memoryPage = null;
+ baseObject = base;
+ keyOffset = offset + 4;
keyLength = Platform.getInt(base, offset);
+ valueOffset = offset + 4 + keyLength;
valueLength = length - 4 - keyLength;
- keyMemoryLocation.setObjAndOffset(base, offset + 4);
- valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
return this;
}
@@ -566,34 +570,44 @@ public final class BytesToBytesMap extends MemoryConsumer {
}
/**
- * Returns the address of the key defined at this position.
- * This points to the first byte of the key data.
- * Unspecified behavior if the key is not defined.
- * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
+ * Returns the base object for key.
*/
- public MemoryLocation getKeyAddress() {
+ public Object getKeyBase() {
assert (isDefined);
- return keyMemoryLocation;
+ return baseObject;
}
/**
- * Returns the length of the key defined at this position.
- * Unspecified behavior if the key is not defined.
+ * Returns the offset for key.
*/
- public int getKeyLength() {
+ public long getKeyOffset() {
assert (isDefined);
- return keyLength;
+ return keyOffset;
+ }
+
+ /**
+ * Returns the base object for value.
+ */
+ public Object getValueBase() {
+ assert (isDefined);
+ return baseObject;
}
/**
- * Returns the address of the value defined at this position.
- * This points to the first byte of the value data.
+ * Returns the offset for value.
+ */
+ public long getValueOffset() {
+ assert (isDefined);
+ return valueOffset;
+ }
+
+ /**
+ * Returns the length of the key defined at this position.
* Unspecified behavior if the key is not defined.
- * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
*/
- public MemoryLocation getValueAddress() {
+ public int getKeyLength() {
assert (isDefined);
- return valueMemoryLocation;
+ return keyLength;
}
/**
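
The rewritten updateAddressesAndSizes() above decodes a record laid out as [int totalLength][int keyLength][key bytes][value bytes], where totalLength covers the keyLength field plus both payloads. A standalone sketch of that decoding, using the same Platform calls (names are illustrative, not part of the patch):

    // Decode one map record; mirrors Location.updateAddressesAndSizes().
    static void decode(Object base, long offset) {
      final int totalLength  = Platform.getInt(base, offset);  // 4 + keyLength + valueLength
      final int keyLength    = Platform.getInt(base, offset + 4);
      final long keyOffset   = offset + 8;                     // what getKeyOffset() returns
      final long valueOffset = keyOffset + keyLength;          // what getValueOffset() returns
      final int valueLength  = totalLength - keyLength - 4;
      // One shared base object serves both key and value, so Location only
      // needs a single Object field plus two long offsets.
    }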
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 702ba5469b..d8af2b336d 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -39,14 +39,13 @@ import org.mockito.stubbing.Answer;
import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.Utils;
import static org.hamcrest.Matchers.greaterThan;
@@ -142,10 +141,9 @@ public abstract class AbstractBytesToBytesMapSuite {
protected abstract boolean useOffHeapMemoryAllocator();
- private static byte[] getByteArray(MemoryLocation loc, int size) {
+ private static byte[] getByteArray(Object base, long offset, int size) {
final byte[] arr = new byte[size];
- Platform.copyMemory(
- loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size);
+ Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
return arr;
}
@@ -163,13 +161,14 @@ public abstract class AbstractBytesToBytesMapSuite {
*/
private static boolean arrayEquals(
byte[] expected,
- MemoryLocation actualAddr,
+ Object base,
+ long offset,
long actualLengthBytes) {
return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
expected,
Platform.BYTE_ARRAY_OFFSET,
- actualAddr.getBaseObject(),
- actualAddr.getBaseOffset(),
+ base,
+ offset,
expected.length
);
}
@@ -212,16 +211,20 @@ public abstract class AbstractBytesToBytesMapSuite {
// reflect the result of this store without us having to call lookup() again on the same key.
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
- Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
- Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+ Assert.assertArrayEquals(keyData,
+ getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
+ Assert.assertArrayEquals(valueData,
+ getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
// After calling lookup() the location should still point to the correct data.
Assert.assertTrue(
map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
- Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
- Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+ Assert.assertArrayEquals(keyData,
+ getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
+ Assert.assertArrayEquals(valueData,
+ getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
try {
Assert.assertTrue(loc.putNewKey(
@@ -283,15 +286,12 @@ public abstract class AbstractBytesToBytesMapSuite {
while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next();
Assert.assertTrue(loc.isDefined());
- final MemoryLocation keyAddress = loc.getKeyAddress();
- final MemoryLocation valueAddress = loc.getValueAddress();
- final long value = Platform.getLong(
- valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+ final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
final long keyLength = loc.getKeyLength();
if (keyLength == 0) {
Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
} else {
- final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+ final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
Assert.assertEquals(value, key);
}
valuesSeen.set((int) value);
@@ -365,15 +365,15 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
Platform.copyMemory(
- loc.getKeyAddress().getBaseObject(),
- loc.getKeyAddress().getBaseOffset(),
+ loc.getKeyBase(),
+ loc.getKeyOffset(),
key,
Platform.LONG_ARRAY_OFFSET,
KEY_LENGTH
);
Platform.copyMemory(
- loc.getValueAddress().getBaseObject(),
- loc.getValueAddress().getBaseOffset(),
+ loc.getValueBase(),
+ loc.getValueOffset(),
value,
Platform.LONG_ARRAY_OFFSET,
VALUE_LENGTH
@@ -425,8 +425,9 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertTrue(loc.isDefined());
Assert.assertEquals(key.length, loc.getKeyLength());
Assert.assertEquals(value.length, loc.getValueLength());
- Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
- Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+ Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
+ Assert.assertTrue(
+ arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
}
}
@@ -436,8 +437,10 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
Assert.assertTrue(loc.isDefined());
- Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
- Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+ Assert.assertTrue(
+ arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
+ Assert.assertTrue(
+ arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
}
} finally {
map.free();
@@ -476,8 +479,9 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertTrue(loc.isDefined());
Assert.assertEquals(key.length, loc.getKeyLength());
Assert.assertEquals(value.length, loc.getValueLength());
- Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
- Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+ Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
+ Assert.assertTrue(
+ arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
}
}
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
@@ -486,8 +490,10 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
Assert.assertTrue(loc.isDefined());
- Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
- Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+ Assert.assertTrue(
+ arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
+ Assert.assertTrue(
+ arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
}
} finally {
map.free();
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9209094385..133894704b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,6 +40,7 @@ object MimaExcludes {
excludePackage("org.apache.spark.rpc"),
excludePackage("org.spark-project.jetty"),
excludePackage("org.apache.spark.unused"),
+ excludePackage("org.apache.spark.unsafe"),
excludePackage("org.apache.spark.util.collection.unsafe"),
excludePackage("org.apache.spark.sql.catalyst"),
excludePackage("org.apache.spark.sql.execution"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index f4ccadd9c5..28e4f50eee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
}
}
-
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
val childrenHash = children.map { child =>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 6bf9d7bd03..2e84178d69 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,19 +121,24 @@ public final class UnsafeFixedWidthAggregationMap {
return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
}
- public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
+ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
+ return getAggregationBufferFromUnsafeRow(key, key.hashCode());
+ }
+
+ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
// Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup(
- unsafeGroupingKeyRow.getBaseObject(),
- unsafeGroupingKeyRow.getBaseOffset(),
- unsafeGroupingKeyRow.getSizeInBytes());
+ key.getBaseObject(),
+ key.getBaseOffset(),
+ key.getSizeInBytes(),
+ hash);
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
boolean putSucceeded = loc.putNewKey(
- unsafeGroupingKeyRow.getBaseObject(),
- unsafeGroupingKeyRow.getBaseOffset(),
- unsafeGroupingKeyRow.getSizeInBytes(),
+ key.getBaseObject(),
+ key.getBaseOffset(),
+ key.getSizeInBytes(),
emptyAggregationBuffer,
Platform.BYTE_ARRAY_OFFSET,
emptyAggregationBuffer.length
@@ -144,10 +149,9 @@ public final class UnsafeFixedWidthAggregationMap {
}
// Reset the pointer to point to the value that we just stored or looked up:
- final MemoryLocation address = loc.getValueAddress();
currentAggregationBuffer.pointTo(
- address.getBaseObject(),
- address.getBaseOffset(),
+ loc.getValueBase(),
+ loc.getValueOffset(),
loc.getValueLength()
);
return currentAggregationBuffer;
@@ -172,16 +176,14 @@ public final class UnsafeFixedWidthAggregationMap {
public boolean next() {
if (mapLocationIterator.hasNext()) {
final BytesToBytesMap.Location loc = mapLocationIterator.next();
- final MemoryLocation keyAddress = loc.getKeyAddress();
- final MemoryLocation valueAddress = loc.getValueAddress();
key.pointTo(
- keyAddress.getBaseObject(),
- keyAddress.getBaseOffset(),
+ loc.getKeyBase(),
+ loc.getKeyOffset(),
loc.getKeyLength()
);
value.pointTo(
- valueAddress.getBaseObject(),
- valueAddress.getBaseOffset(),
+ loc.getValueBase(),
+ loc.getValueOffset(),
loc.getValueLength()
);
return true;
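
A short sketch of the call pattern the two overloads above enable, assuming generated code that already holds the key's hash (variable names are illustrative):

    // Without a precomputed hash: delegates to key.hashCode(), which is
    // Murmur3 over the row bytes with seed 42.
    UnsafeRow buffer = aggMap.getAggregationBufferFromUnsafeRow(key);
    // With the hash from generated code: the map never re-hashes the key bytes.
    UnsafeRow buffer2 = aggMap.getAggregationBufferFromUnsafeRow(key, hash);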
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 0da26bf376..51e10b0e93 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -97,8 +97,8 @@ public final class UnsafeKVExternalSorter {
UnsafeRow row = new UnsafeRow(numKeyFields);
while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next();
- final Object baseObject = loc.getKeyAddress().getBaseObject();
- final long baseOffset = loc.getKeyAddress().getBaseOffset();
+ final Object baseObject = loc.getKeyBase();
+ final long baseOffset = loc.getKeyOffset();
// Get encoded memory address
// baseObject + baseOffset point to the beginning of the key data in the map, but that
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 4ca2d85406..b200239c94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) {
plan.transform {
- case plan: CodegenSupport if supportCodegen(plan) &&
- // Whole stage codegen is only useful when there are at least two levels of operators that
- // support it (save at least one projection/iterator).
- (Utils.isTesting || plan.children.exists(supportCodegen)) =>
-
+ case plan: CodegenSupport if supportCodegen(plan) =>
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 9d9f14f2dd..340b8f78e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -501,6 +501,11 @@ case class TungstenAggregate(
}
}
+ // generate hash code for key
+ val hashExpr = Murmur3Hash(groupingExpressions, 42)
+ ctx.currentVars = input
+ val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
+
val inputAttr = bufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
@@ -526,10 +531,11 @@ case class TungstenAggregate(
s"""
// generate grouping key
${keyCode.code.trim}
+ ${hashEval.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
- $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
}
if ($buffer == null) {
if ($sorterTerm == null) {
@@ -540,7 +546,7 @@ case class TungstenAggregate(
$resetCoulter
// the hash map had be spilled, it should have enough memory now,
// try to allocate buffer again.
- $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
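
For intuition, a hedged sketch of the shape of Java that the Murmur3Hash(groupingExpressions, 42) expression generates for a single non-null int grouping key; the real codegen output differs in naming and null handling:

    // Illustrative only, not actual codegen output.
    int agg_hash = 42;  // seed
    agg_hash = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(agg_keyValue, agg_hash);
    // agg_hash is then passed to getAggregationBufferFromUnsafeRow(agg_key, agg_hash),
    // sparing the map from re-hashing the serialized key.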
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index c94d6c195b..eb6930a14f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation(
val map = binaryMap // avoid the compiler error
val loc = new map.Location // this could be allocated in stack
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc)
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
if (loc.isDefined) {
val buffer = CompactBuffer[UnsafeRow]()
- val base = loc.getValueAddress.getBaseObject
- var offset = loc.getValueAddress.getBaseOffset
- val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+ val base = loc.getValueBase
+ var offset = loc.getValueOffset
+ val last = offset + loc.getValueLength
while (offset < last) {
val numFields = Platform.getInt(base, offset)
val sizeInBytes = Platform.getInt(base, offset + 4)
@@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation(
out.writeInt(binaryMap.numElements())
var buffer = new Array[Byte](64)
- def write(addr: MemoryLocation, length: Int): Unit = {
+ def write(base: Object, offset: Long, length: Int): Unit = {
if (buffer.length < length) {
buffer = new Array[Byte](length)
}
- Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
- buffer, Platform.BYTE_ARRAY_OFFSET, length)
+ Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
out.write(buffer, 0, length)
}
@@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation(
// [key size] [values size] [key bytes] [values bytes]
out.writeInt(loc.getKeyLength)
out.writeInt(loc.getValueLength)
- write(loc.getKeyAddress, loc.getKeyLength)
- write(loc.getValueAddress, loc.getValueLength)
+ write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+ write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index f015d29704..dc6c647a4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
/*
- Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
- -------------------------------------------------------------------------------------------
- Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X
- Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Aggregate w keys codegen=false 2429 / 2644 8.6 115.8 1.0X
+ Aggregate w keys codegen=true 1535 / 1571 13.7 73.2 1.6X
*/
}
@@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.addCase("hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
- val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- val value = new UnsafeRow(2)
- value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var s = 0
while (i < N) {
key.setInt(0, i % 1000)
val h = Murmur3_x86_32.hashUnsafeWords(
- key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42)
+ s += h
+ i += 1
+ }
+ }
+
+ benchmark.addCase("fast hash") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 1000)
+ val h = Murmur3_x86_32.hashLong(i % 1000, 42)
s += h
i += 1
}
}
+ benchmark.addCase("arrayEqual") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 1000)
+ if (key.equals(value)) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- val value = new UnsafeRow(2)
+ val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0
while (i < N) {
key.setInt(0, i % 65536)
- val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ Murmur3_x86_32.hashLong(i % 65536, 42))
if (loc.isDefined) {
- value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
- loc.getValueLength)
+ value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
value.setInt(0, value.getInt(0) + 1)
i += 1
} else {
@@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- hash 628 / 661 83.0 12.0 1.0X
- BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X
- BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X
+ hash 651 / 678 80.0 12.5 1.0X
+ fast hash 336 / 343 155.9 6.4 1.9X
+ arrayEqual 417 / 428 125.0 8.0 1.6X
+ BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
+ BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
*/
benchmark.run()
}