aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2015-07-31 23:55:16 -0700
committer Reynold Xin <rxin@databricks.com> 2015-07-31 23:55:16 -0700
commit d90f2cf7a2a1d1e69f9ab385f35f62d4091b5302 (patch)
tree 94dff8456047924b32f7295dca1e7f47702d5e16 /sql
parent 67ad4e21fc68336b0ad6f9a363fb5ebb51f592bf (diff)
download spark-d90f2cf7a2a1d1e69f9ab385f35f62d4091b5302.tar.gz
spark-d90f2cf7a2a1d1e69f9ab385f35f62d4091b5302.tar.bz2
spark-d90f2cf7a2a1d1e69f9ab385f35f62d4091b5302.zip
[SPARK-9517][SQL] BytesToBytesMap should encode data the same way as UnsafeExternalSorter
BytesToBytesMap currently encodes key/value data in the following format: ``` 8B key length, key data, 8B value length, value data ``` UnsafeExternalSorter, on the other hand, encodes data this way: ``` 4B record length, data ``` As a result, we cannot pass records encoded by BytesToBytesMap directly into UnsafeExternalSorter for sorting. However, if we rearrange data slightly, we can then pass the key/value records directly into UnsafeExternalSorter: ``` 4B key+value length, 4B key length, key data, value data ``` Author: Reynold Xin <rxin@databricks.com> Closes #7845 from rxin/kvsort-rebase and squashes the following commits: 5716b59 [Reynold Xin] Fixed test. 2e62ccb [Reynold Xin] Updated BytesToBytesMap's data encoding to put the key first. a51b641 [Reynold Xin] Added a KV sorter interface.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java30
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java73
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala27
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala27
4 files changed, 97 insertions, 60 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java
new file mode 100644
index 0000000000..59c774da74
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.unsafe.KVIterator;
+
+public abstract class UnsafeKeyValueSorter {
+
+ public abstract void insert(UnsafeRow key, UnsafeRow value);
+
+ public abstract KVIterator<UnsafeRow, UnsafeRow> sort() throws IOException;
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 08a98cdd94..c18b6dea6b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -17,9 +17,6 @@
package org.apache.spark.sql.execution;
-import java.io.IOException;
-import java.util.Iterator;
-
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
@@ -28,6 +25,7 @@ import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
@@ -157,53 +155,54 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
- * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
- */
- public static class MapEntry {
- private MapEntry() { };
- public final UnsafeRow key = new UnsafeRow();
- public final UnsafeRow value = new UnsafeRow();
- }
-
- /**
* Returns an iterator over the keys and values in this map.
*
* For efficiency, each call returns the same object.
*/
- public Iterator<MapEntry> iterator() {
- return new Iterator<MapEntry>() {
+ public KVIterator<UnsafeRow, UnsafeRow> iterator() {
+ return new KVIterator<UnsafeRow, UnsafeRow>() {
+
+ private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator();
+ private final UnsafeRow key = new UnsafeRow();
+ private final UnsafeRow value = new UnsafeRow();
- private final MapEntry entry = new MapEntry();
- private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+ @Override
+ public boolean next() {
+ if (mapLocationIterator.hasNext()) {
+ final BytesToBytesMap.Location loc = mapLocationIterator.next();
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ key.pointTo(
+ keyAddress.getBaseObject(),
+ keyAddress.getBaseOffset(),
+ groupingKeySchema.length(),
+ loc.getKeyLength()
+ );
+ value.pointTo(
+ valueAddress.getBaseObject(),
+ valueAddress.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
+ );
+ return true;
+ } else {
+ return false;
+ }
+ }
@Override
- public boolean hasNext() {
- return mapLocationIterator.hasNext();
+ public UnsafeRow getKey() {
+ return key;
}
@Override
- public MapEntry next() {
- final BytesToBytesMap.Location loc = mapLocationIterator.next();
- final MemoryLocation keyAddress = loc.getKeyAddress();
- final MemoryLocation valueAddress = loc.getValueAddress();
- entry.key.pointTo(
- keyAddress.getBaseObject(),
- keyAddress.getBaseOffset(),
- groupingKeySchema.length(),
- loc.getKeyLength()
- );
- entry.value.pointTo(
- valueAddress.getBaseObject(),
- valueAddress.getBaseOffset(),
- aggregationBufferSchema.length(),
- loc.getValueLength()
- );
- return entry;
+ public UnsafeRow getValue() {
+ return value;
}
@Override
- public void remove() {
- throw new UnsupportedOperationException();
+ public void close() {
+ // Do nothing.
}
};
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 469de6ca8e..cd87b8deba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -287,21 +287,26 @@ case class GeneratedAggregate(
new Iterator[InternalRow] {
private[this] val mapIterator = aggregationMap.iterator()
private[this] val resultProjection = resultProjectionBuilder()
+ private[this] var _hasNext = mapIterator.next()
- def hasNext: Boolean = mapIterator.hasNext
+ def hasNext: Boolean = _hasNext
def next(): InternalRow = {
- val entry = mapIterator.next()
- val result = resultProjection(joinedRow(entry.key, entry.value))
- if (hasNext) {
- result
+ if (_hasNext) {
+ val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue))
+ _hasNext = mapIterator.next()
+ if (_hasNext) {
+ result
+ } else {
+ // This is the last element in the iterator, so let's free the buffer. Before we do,
+ // though, we need to make a defensive copy of the result so that we don't return an
+ // object that might contain dangling pointers to the freed memory
+ val resultCopy = result.copy()
+ aggregationMap.free()
+ resultCopy
+ }
} else {
- // This is the last element in the iterator, so let's free the buffer. Before we do,
- // though, we need to make a defensive copy of the result so that we don't return an
- // object that might contain dangling pointers to the freed memory
- val resultCopy = result.copy()
- aggregationMap.free()
- resultCopy
+ throw new java.util.NoSuchElementException
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 79fd52dacd..6a2c51ca88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.scalatest.{BeforeAndAfterEach, Matchers}
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Random
import org.apache.spark.SparkFunSuite
@@ -52,7 +53,7 @@ class UnsafeFixedWidthAggregationMapSuite
override def afterEach(): Unit = {
if (taskMemoryManager != null) {
- val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
+ val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
assert(leakedShuffleMemory === 0)
taskMemoryManager = null
@@ -80,7 +81,7 @@ class UnsafeFixedWidthAggregationMapSuite
PAGE_SIZE_BYTES,
false // disable perf metrics
)
- assert(!map.iterator().hasNext)
+ assert(!map.iterator().next())
map.free()
}
@@ -100,13 +101,13 @@ class UnsafeFixedWidthAggregationMapSuite
// Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
assert(map.getAggregationBuffer(groupKey) != null)
val iter = map.iterator()
- val entry = iter.next()
- assert(!iter.hasNext)
- entry.key.getString(0) should be ("cats")
- entry.value.getInt(0) should be (0)
+ assert(iter.next())
+ iter.getKey.getString(0) should be ("cats")
+ iter.getValue.getInt(0) should be (0)
+ assert(!iter.next())
// Modifications to rows retrieved from the map should update the values in the map
- entry.value.setInt(0, 42)
+ iter.getValue.setInt(0, 42)
map.getAggregationBuffer(groupKey).getInt(0) should be (42)
map.free()
@@ -128,12 +129,14 @@ class UnsafeFixedWidthAggregationMapSuite
groupKeys.foreach { keyString =>
assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
}
- val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
- entry.key.getString(0)
- }.toSet
- seenKeys.size should be (groupKeys.size)
- seenKeys should be (groupKeys)
+ val seenKeys = new mutable.HashSet[String]
+ val iter = map.iterator()
+ while (iter.next()) {
+ seenKeys += iter.getKey.getString(0)
+ }
+ assert(seenKeys.size === groupKeys.size)
+ assert(seenKeys === groupKeys)
map.free()
}