author     Davies Liu <davies@databricks.com>    2016-02-09 16:41:21 -0800
committer  Davies Liu <davies.liu@gmail.com>     2016-02-09 16:41:21 -0800
commit     0e5ebac3c1f1ff58f938be59c7c9e604977d269c (patch)
tree       6a572ee7aa4f79f285a78a083dc69eef87627be0 /sql
parent     fae830d15846f7ffdfe49eeb45e175a3cdd2c670 (diff)
[SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate
This PR improves the lookup of BytesToBytesMap by:

1. Generating code that calculates the hash code of the grouping keys.
2. Not using MemoryLocation: the baseObject and offset for key and value are fetched directly, removing the indirection.

Author: Davies Liu <davies@databricks.com>

Closes #11010 from davies/gen_map.
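As a reader aid, here is a minimal sketch of the first idea, with a hypothetical SimpleBytesMap standing in for BytesToBytesMap: the caller (in Spark, the generated aggregate code) hashes the grouping key once per row and passes that hash into lookup, so the map never rehashes the key bytes itself. The class and its internals are illustrative assumptions, not the real map.

    import java.util.Arrays

    // Hypothetical stand-in for BytesToBytesMap, keyed by raw byte arrays.
    final class SimpleBytesMap[V] {
      private val buckets =
        scala.collection.mutable.HashMap.empty[Int, List[(Array[Byte], V)]]

      // Old-style lookup: the map hashes the key itself on every probe.
      def lookup(key: Array[Byte]): Option[V] =
        lookup(key, Arrays.hashCode(key))

      // New-style lookup: the caller supplies a hash it already computed
      // (in this PR, a codegen'd Murmur3Hash of the grouping key, seed 42).
      def lookup(key: Array[Byte], hash: Int): Option[V] =
        buckets.getOrElse(hash, Nil).collectFirst {
          case (k, v) if Arrays.equals(k, key) => v
        }

      def put(key: Array[Byte], hash: Int, value: V): Unit =
        buckets(hash) = (key, value) :: buckets.getOrElse(hash, Nil)
    }

    // Usage: hash once per row, reuse the hash for both lookup and insert.
    val map = new SimpleBytesMap[Int]
    val key = Array[Byte](1, 2, 3)
    val h = Arrays.hashCode(key)
    map.put(key, h, 0)
    assert(map.lookup(key, h).contains(0))

With the single-argument overload every probe pays for hashing the key bytes again; with the two-argument form the hash produced by the generated code is computed once and reused.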
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala           |  1
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java  | 34
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java          |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala             |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala   | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala          | 17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala    | 64
7 files changed, 85 insertions(+), 51 deletions(-)
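The second idea shows up throughout the diff below: BytesToBytesMap.Location now exposes the base object and offset directly (getKeyBase/getKeyOffset, getValueBase/getValueOffset) instead of handing out an intermediate MemoryLocation. A toy schematic follows; the accessor names mirror the diff, but the class bodies are assumptions made for illustration.

    // Toy schematic of the indirection removal; not the real classes.
    final class MemoryLocation(val getBaseObject: AnyRef, val getBaseOffset: Long)

    final class Location(base: AnyRef, offset: Long, val getValueLength: Int) {
      // Before: two hops per field access, through a wrapper object.
      def getValueAddress: MemoryLocation = new MemoryLocation(base, offset)
      // After: one direct accessor each on the hot lookup path.
      def getValueBase: AnyRef = base
      def getValueOffset: Long = offset
    }

    // Before: loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset
    // After:  loc.getValueBase, loc.getValueOffset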
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index f4ccadd9c5..28e4f50eee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
}
}
-
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
val childrenHash = children.map { child =>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 6bf9d7bd03..2e84178d69 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,19 +121,24 @@ public final class UnsafeFixedWidthAggregationMap {
return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
}
- public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
+ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
+ return getAggregationBufferFromUnsafeRow(key, key.hashCode());
+ }
+
+ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
// Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup(
- unsafeGroupingKeyRow.getBaseObject(),
- unsafeGroupingKeyRow.getBaseOffset(),
- unsafeGroupingKeyRow.getSizeInBytes());
+ key.getBaseObject(),
+ key.getBaseOffset(),
+ key.getSizeInBytes(),
+ hash);
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
boolean putSucceeded = loc.putNewKey(
- unsafeGroupingKeyRow.getBaseObject(),
- unsafeGroupingKeyRow.getBaseOffset(),
- unsafeGroupingKeyRow.getSizeInBytes(),
+ key.getBaseObject(),
+ key.getBaseOffset(),
+ key.getSizeInBytes(),
emptyAggregationBuffer,
Platform.BYTE_ARRAY_OFFSET,
emptyAggregationBuffer.length
@@ -144,10 +149,9 @@ public final class UnsafeFixedWidthAggregationMap {
}
// Reset the pointer to point to the value that we just stored or looked up:
- final MemoryLocation address = loc.getValueAddress();
currentAggregationBuffer.pointTo(
- address.getBaseObject(),
- address.getBaseOffset(),
+ loc.getValueBase(),
+ loc.getValueOffset(),
loc.getValueLength()
);
return currentAggregationBuffer;
@@ -172,16 +176,14 @@ public final class UnsafeFixedWidthAggregationMap {
public boolean next() {
if (mapLocationIterator.hasNext()) {
final BytesToBytesMap.Location loc = mapLocationIterator.next();
- final MemoryLocation keyAddress = loc.getKeyAddress();
- final MemoryLocation valueAddress = loc.getValueAddress();
key.pointTo(
- keyAddress.getBaseObject(),
- keyAddress.getBaseOffset(),
+ loc.getKeyBase(),
+ loc.getKeyOffset(),
loc.getKeyLength()
);
value.pointTo(
- valueAddress.getBaseObject(),
- valueAddress.getBaseOffset(),
+ loc.getValueBase(),
+ loc.getValueOffset(),
loc.getValueLength()
);
return true;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 0da26bf376..51e10b0e93 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -97,8 +97,8 @@ public final class UnsafeKVExternalSorter {
UnsafeRow row = new UnsafeRow(numKeyFields);
while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next();
- final Object baseObject = loc.getKeyAddress().getBaseObject();
- final long baseOffset = loc.getKeyAddress().getBaseOffset();
+ final Object baseObject = loc.getKeyBase();
+ final long baseOffset = loc.getKeyOffset();
// Get encoded memory address
// baseObject + baseOffset point to the beginning of the key data in the map, but that
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 4ca2d85406..b200239c94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) {
plan.transform {
- case plan: CodegenSupport if supportCodegen(plan) &&
- // Whole stage codegen is only useful when there are at least two levels of operators that
- // support it (save at least one projection/iterator).
- (Utils.isTesting || plan.children.exists(supportCodegen)) =>
-
+ case plan: CodegenSupport if supportCodegen(plan) =>
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 9d9f14f2dd..340b8f78e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -501,6 +501,11 @@ case class TungstenAggregate(
}
}
+ // generate hash code for key
+ val hashExpr = Murmur3Hash(groupingExpressions, 42)
+ ctx.currentVars = input
+ val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
+
val inputAttr = bufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
@@ -526,10 +531,11 @@ case class TungstenAggregate(
s"""
// generate grouping key
${keyCode.code.trim}
+ ${hashEval.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
- $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
}
if ($buffer == null) {
if ($sorterTerm == null) {
@@ -540,7 +546,7 @@ case class TungstenAggregate(
$resetCoulter
// the hash map had be spilled, it should have enough memory now,
// try to allocate buffer again.
- $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index c94d6c195b..eb6930a14f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation(
val map = binaryMap // avoid the compiler error
val loc = new map.Location // this could be allocated in stack
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc)
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
if (loc.isDefined) {
val buffer = CompactBuffer[UnsafeRow]()
- val base = loc.getValueAddress.getBaseObject
- var offset = loc.getValueAddress.getBaseOffset
- val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+ val base = loc.getValueBase
+ var offset = loc.getValueOffset
+ val last = offset + loc.getValueLength
while (offset < last) {
val numFields = Platform.getInt(base, offset)
val sizeInBytes = Platform.getInt(base, offset + 4)
@@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation(
out.writeInt(binaryMap.numElements())
var buffer = new Array[Byte](64)
- def write(addr: MemoryLocation, length: Int): Unit = {
+ def write(base: Object, offset: Long, length: Int): Unit = {
if (buffer.length < length) {
buffer = new Array[Byte](length)
}
- Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
- buffer, Platform.BYTE_ARRAY_OFFSET, length)
+ Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
out.write(buffer, 0, length)
}
@@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation(
// [key size] [values size] [key bytes] [values bytes]
out.writeInt(loc.getKeyLength)
out.writeInt(loc.getValueLength)
- write(loc.getKeyAddress, loc.getKeyLength)
- write(loc.getValueAddress, loc.getValueLength)
+ write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+ write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index f015d29704..dc6c647a4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
/*
- Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
- -------------------------------------------------------------------------------------------
- Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X
- Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Aggregate w keys codegen=false 2429 / 2644 8.6 115.8 1.0X
+ Aggregate w keys codegen=true 1535 / 1571 13.7 73.2 1.6X
*/
}
@@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.addCase("hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
- val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- val value = new UnsafeRow(2)
- value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var s = 0
while (i < N) {
key.setInt(0, i % 1000)
val h = Murmur3_x86_32.hashUnsafeWords(
- key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42)
+ s += h
+ i += 1
+ }
+ }
+
+ benchmark.addCase("fast hash") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 1000)
+ val h = Murmur3_x86_32.hashLong(i % 1000, 42)
s += h
i += 1
}
}
+ benchmark.addCase("arrayEqual") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 1000)
+ if (key.equals(value)) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- val value = new UnsafeRow(2)
+ val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0
while (i < N) {
key.setInt(0, i % 65536)
- val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ Murmur3_x86_32.hashLong(i % 65536, 42))
if (loc.isDefined) {
- value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
- loc.getValueLength)
+ value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
value.setInt(0, value.getInt(0) + 1)
i += 1
} else {
@@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- hash 628 / 661 83.0 12.0 1.0X
- BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X
- BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X
+ hash 651 / 678 80.0 12.5 1.0X
+ fast hash 336 / 343 155.9 6.4 1.9X
+ arrayEqual 417 / 428 125.0 8.0 1.6X
+ BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
+ BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
*/
benchmark.run()
}