diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala | 297 |
1 file changed, 297 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala new file mode 100644 index 0000000000..d98c7aa3d7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.{Arrays, Comparator} + +/** + * A simple open hash table optimized for the append-only use case, where keys + * are never removed, but the value for each key may be changed. + * + * This implementation uses quadratic probing with a power-of-2 hash table + * size, which is guaranteed to explore all spaces for each key (see + * http://en.wikipedia.org/wiki/Quadratic_probing). + * + * TODO: Cache the hash values of each key? java.util.HashMap does that. 
+ */ +private[spark] +class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, + V)] with Serializable { + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + private var capacity = nextPowerOf2(initialCapacity) + private var mask = capacity - 1 + private var curSize = 0 + private var growThreshold = (LOAD_FACTOR * capacity).toInt + + // Holds keys and values in the same array for memory locality; specifically, the order of + // elements is key0, value0, key1, value1, key2, value2, etc. + private var data = new Array[AnyRef](2 * capacity) + + // Treat the null key differently so we can use nulls in "data" to represent empty items. + private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + // Triggered by destructiveSortedIterator; the underlying data array may no longer be used + private var destroyed = false + private val destructionMessage = "Map state is invalid from destructive sorting!" 
+ + private val LOAD_FACTOR = 0.7 + + /** Get the value for a given key */ + def apply(key: K): V = { + assert(!destroyed, destructionMessage) + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + return nullValue + } + var pos = rehash(k.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (k.eq(curKey) || k.equals(curKey)) { + return data(2 * pos + 1).asInstanceOf[V] + } else if (curKey.eq(null)) { + return null.asInstanceOf[V] + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + return null.asInstanceOf[V] + } + + /** Set the value for a key */ + def update(key: K, value: V): Unit = { + assert(!destroyed, destructionMessage) + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + if (!haveNullValue) { + incrementSize() + } + nullValue = value + haveNullValue = true + return + } + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = k + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + incrementSize() // Since we added a new key + return + } else if (k.eq(curKey) || k.equals(curKey)) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + } + + /** + * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value + * for key, if any, or null otherwise. Returns the newly updated value. 
   */
  def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    assert(!destroyed, destructionMessage)
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      // Null key lives outside the array. Note the *old* haveNullValue flag is passed to
      // updateFunc (incrementSize() above does not touch it), so the callback sees whether
      // a previous value existed.
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue)
      haveNullValue = true
      return nullValue
    }
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (k.eq(curKey) || k.equals(curKey)) {
        // Key already present: fold the old value through updateFunc.
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
      } else if (curKey.eq(null)) {
        // Empty slot reached: the key is absent, insert a fresh entry.
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        incrementSize()
        return newValue
      } else {
        // Quadratic probing: step size grows by 1 on each collision.
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

  /** Iterator method from Iterable */
  override def iterator: Iterator[(K, V)] = {
    assert(!destroyed, destructionMessage)
    new Iterator[(K, V)] {
      var pos = -1 // -1 designates the (out-of-array) null-key entry

      /** Get the next value we should return from next(), or null if we're finished iterating */
      def nextValue(): (K, V) = {
        if (pos == -1) { // Treat position -1 as looking at the null value
          if (haveNullValue) {
            return (null.asInstanceOf[K], nullValue)
          }
          pos += 1
        }
        // Note: does NOT advance pos past a found entry, so repeated hasNext calls
        // (which also go through nextValue) stay idempotent; next() advances pos itself.
        while (pos < capacity) {
          if (!data(2 * pos).eq(null)) {
            return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
          }
          pos += 1
        }
        null
      }

      override def hasNext: Boolean = nextValue() != null

      override def next(): (K, V) = {
        val value = nextValue()
        if (value == null) {
          throw new NoSuchElementException("End of iterator")
        }
        pos += 1 // Consume the entry nextValue() stopped on
        value
      }
    }
  }

  override def size: Int = curSize

  /** Increase table size by 1, rehashing if necessary */
  private def incrementSize() {
    curSize += 1
    if (curSize > growThreshold) {
      growTable()
    }
  }

  /**
   * Re-hash a
value to deal better with hash functions that don't differ in the lower bits.
   * We use the Murmur Hash 3 finalization step that's also used in fastutil.
   */
  private def rehash(h: Int): Int = {
    // Inlined MurmurHash3 finalizer (same constants and shifts used by
    // it.unimi.dsi.fastutil.HashCommon.murmurHash3), so this class no longer needs
    // fastutil on the classpath just for a five-instruction bit mixer.
    // Produces bit-identical values to the fastutil call it replaces.
    var x = h
    x ^= x >>> 16
    x *= 0x85ebca6b
    x ^= x >>> 13
    x *= 0xc2b2ae35
    x ^= x >>> 16
    x
  }

  /** Double the table's size and re-hash everything */
  protected def growTable() {
    val newCapacity = capacity * 2
    if (newCapacity >= (1 << 30)) {
      // We can't make the table this big because we want an array of 2x
      // that size for our data, but array sizes are at most Int.MaxValue
      throw new Exception("Can't make capacity bigger than 2^29 elements")
    }
    val newData = new Array[AnyRef](2 * newCapacity)
    val newMask = newCapacity - 1
    // Insert all our old values into the new array. Note that because our old keys are
    // unique, there's no need to check for equality here when we insert.
    var oldPos = 0
    while (oldPos < capacity) {
      if (!data(2 * oldPos).eq(null)) {
        val key = data(2 * oldPos)
        val value = data(2 * oldPos + 1)
        var newPos = rehash(key.hashCode) & newMask
        var i = 1
        var keepGoing = true
        while (keepGoing) {
          val curKey = newData(2 * newPos)
          if (curKey.eq(null)) {
            newData(2 * newPos) = key
            newData(2 * newPos + 1) = value
            keepGoing = false
          } else {
            // Quadratic probing, same scheme as apply/update/changeValue.
            val delta = i
            newPos = (newPos + delta) & newMask
            i += 1
          }
        }
      }
      oldPos += 1
    }
    data = newData
    capacity = newCapacity
    mask = newMask
    growThreshold = (LOAD_FACTOR * newCapacity).toInt
  }

  /** Round n up to the next power of 2, or return n itself if it already is one. */
  private def nextPowerOf2(n: Int): Int = {
    val highBit = Integer.highestOneBit(n)
    if (highBit == n) n else highBit << 1
  }

  /**
   * Return an iterator of the map in sorted order. This provides a way to sort the map without
   * using additional memory, at the expense of destroying the validity of the map.
   */
  def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = {
    destroyed = true
    // Pack KV pairs into the front of the underlying array, one tuple per slot.
    // Safe to do in place because newIndex <= keyIndex always holds, so the write at
    // data(newIndex) never clobbers a slot (2 * keyIndex or 2 * keyIndex + 1) that has
    // not yet been read — the tuple is constructed before the assignment happens.
    var keyIndex, newIndex = 0
    while (keyIndex < capacity) {
      if (data(2 * keyIndex) != null) {
        data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1))
        newIndex += 1
      }
      keyIndex += 1
    }
    // The null key (if any) is tracked outside the array, hence the adjustment.
    assert(curSize == newIndex + (if (haveNullValue) 1 else 0))

    // Sort by the given ordering; adapt cmp to AnyRef since `data` is untyped.
    val rawOrdering = new Comparator[AnyRef] {
      def compare(x: AnyRef, y: AnyRef): Int = {
        cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)])
      }
    }
    Arrays.sort(data, 0, newIndex, rawOrdering)

    new Iterator[(K, V)] {
      var i = 0
      // The null-key entry (if present) is emitted first, ahead of the sorted run.
      var nullValueReady = haveNullValue
      def hasNext: Boolean = (i < newIndex || nullValueReady)
      def next(): (K, V) = {
        if (nullValueReady) {
          nullValueReady = false
          (null.asInstanceOf[K], nullValue)
        } else {
          val item = data(i).asInstanceOf[(K, V)]
          i += 1
          item
        }
      }
    }
  }

  /**
   * Return whether the next insert will cause the map to grow
   */
  def atGrowThreshold: Boolean = curSize == growThreshold
}