diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/collection')
5 files changed, 710 insertions, 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.collection

import scala.reflect._
import scala.reflect.ClassTag

// =================================================================================================
// core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
// =================================================================================================

/**
 * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
 * safety/bound checking: callers must only pass indices in `[0, numBits)`.
 *
 * @param numBits number of bits the set must be able to hold
 */
class BitSet(numBits: Int) {

  // One Long holds 64 bits; bit `i` lives in word `i >> 6` at offset `i & 0x3f`.
  private[this] val words = new Array[Long](bit2words(numBits))
  private[this] val numWords = words.length

  /**
   * Sets the bit at the specified index to true.
   *
   * @param index the bit index
   */
  def set(index: Int): Unit = {
    val bitmask = 1L << (index & 0x3f)  // mod 64 and shift
    words(index >> 6) |= bitmask        // div by 64 and mask
  }

  /**
   * Return the value of the bit with the specified index. The value is true if the bit with
   * the index is currently set in this BitSet; otherwise, the result is false.
   *
   * @param index the bit index
   * @return the value of the bit with the specified index
   */
  def get(index: Int): Boolean = {
    val bitmask = 1L << (index & 0x3f)   // mod 64 and shift
    (words(index >> 6) & bitmask) != 0   // div by 64 and mask
  }

  /** Return the number of bits set to true in this BitSet. */
  def cardinality(): Int = {
    var sum = 0
    var i = 0
    while (i < numWords) {
      sum += java.lang.Long.bitCount(words(i))
      i += 1
    }
    sum
  }

  /**
   * Returns the index of the first bit that is set to true that occurs on or after the
   * specified starting index. If no such bit exists then -1 is returned.
   *
   * To iterate over the true bits in a BitSet, use the following loop:
   *
   * {{{
   *   var i = bs.nextSetBit(0)
   *   while (i >= 0) {
   *     // operate on index i here
   *     i = bs.nextSetBit(i + 1)
   *   }
   * }}}
   *
   * @param fromIndex the index to start checking from (inclusive)
   * @return the index of the next set bit, or -1 if there is no such bit
   */
  def nextSetBit(fromIndex: Int): Int = {
    var wordIndex = fromIndex >> 6
    if (wordIndex >= numWords) {
      return -1
    }

    // Try to find the next set bit in the current word. The arithmetic shift may smear the sign
    // bit into the high positions, but the lowest set bit (the only one we look at) is unaffected.
    val subIndex = fromIndex & 0x3f
    var word = words(wordIndex) >> subIndex
    if (word != 0) {
      return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
    }

    // Find the next set bit in the rest of the words
    wordIndex += 1
    while (wordIndex < numWords) {
      word = words(wordIndex)
      if (word != 0) {
        return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
      }
      wordIndex += 1
    }

    -1
  }

  /** Return the number of longs it would take to hold numBits. */
  private def bit2words(numBits: Int): Int = ((numBits - 1) >> 6) + 1
}

// =================================================================================================
// core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
// =================================================================================================

/**
 * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
 * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
 * space overhead.
 *
 * Under the hood, it uses our OpenHashSet implementation: the set stores the keys and this class
 * keeps a parallel value array indexed by the key's slot position.
 */
private[spark]
class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
    initialCapacity: Int)
  extends Iterable[(K, V)]
  with Serializable {

  def this() = this(64)

  protected var _keySet = new OpenHashSet[K](initialCapacity)

  // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
  // bug that would generate two arrays (one for Object and one for specialized T).
  private var _values: Array[V] = _
  _values = new Array[V](_keySet.capacity)

  // Only populated while a rehash is in flight (between `grow` and the last `move`).
  @transient private var _oldValues: Array[V] = null

  // Treat the null key differently so we can use nulls in "data" to represent empty items.
  private var haveNullValue = false
  private var nullValue: V = null.asInstanceOf[V]

  override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size

  /**
   * Get the value for a given key.
   *
   * @return the stored value, or `null.asInstanceOf[V]` (null / zero) if the key is absent
   */
  def apply(k: K): V = {
    if (k == null) {
      nullValue
    } else {
      val pos = _keySet.getPos(k)
      if (pos < 0) {
        null.asInstanceOf[V]
      } else {
        _values(pos)
      }
    }
  }

  /** Set the value for a key */
  def update(k: K, v: V): Unit = {
    if (k == null) {
      haveNullValue = true
      nullValue = v
    } else {
      // Strip the "newly inserted" flag; we overwrite the slot either way.
      val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
      _values(pos) = v
      _keySet.rehashIfNeeded(k, grow, move)
      _oldValues = null
    }
  }

  /**
   * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
   * set its value to mergeValue(oldValue).
   *
   * @return the newly updated value.
   */
  def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
    if (k == null) {
      if (haveNullValue) {
        nullValue = mergeValue(nullValue)
      } else {
        haveNullValue = true
        nullValue = defaultValue
      }
      nullValue
    } else {
      val pos = _keySet.addWithoutResize(k)
      if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
        // Newly inserted key: store the default, then let the key set grow if it is overloaded.
        val newValue = defaultValue
        _values(pos & OpenHashSet.POSITION_MASK) = newValue
        _keySet.rehashIfNeeded(k, grow, move)
        // Release the pre-rehash array (if any), exactly as update() does.
        _oldValues = null
        newValue
      } else {
        // Existing key: the flag bit is clear, so pos is already the plain slot index.
        _values(pos) = mergeValue(_values(pos))
        _values(pos)
      }
    }
  }

  /** Iterates over all pairs; the null-key entry (if present) is returned first. */
  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
    var pos = -1
    var nextPair: (K, V) = computeNextPair()

    /** Get the next value we should return from next(), or null if we're finished iterating */
    def computeNextPair(): (K, V) = {
      if (pos == -1) {    // Treat position -1 as looking at the null value
        if (haveNullValue) {
          pos += 1
          return (null.asInstanceOf[K], nullValue)
        }
        pos += 1
      }
      pos = _keySet.nextPos(pos)
      if (pos >= 0) {
        val ret = (_keySet.getValue(pos), _values(pos))
        pos += 1
        ret
      } else {
        null
      }
    }

    def hasNext: Boolean = nextPair != null

    def next(): (K, V) = {
      val pair = nextPair
      nextPair = computeNextPair()
      pair
    }
  }

  // The following member variables are declared as protected instead of private for the
  // specialization to work (specialized class extends the non-specialized one and needs access
  // to the "private" variables).
  // They also should have been val's. We use var's because there is a Scala compiler bug that
  // would throw illegal access error at runtime if they are declared as val's.

  /** Rehash callback: stash the current values and allocate the larger value array. */
  protected var grow: Int => Unit = (newCapacity: Int) => {
    _oldValues = _values
    _values = new Array[V](newCapacity)
  }

  /** Rehash callback: carry one value from its old slot to the key's new slot. */
  protected var move: (Int, Int) => Unit = (oldPos: Int, newPos: Int) => {
    _values(newPos) = _oldValues(oldPos)
  }
}

// =================================================================================================
// core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
// =================================================================================================

/**
 * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
 * removed.
 *
 * The underlying implementation uses Scala compiler's specialization to generate optimized
 * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
 * while incurring much less memory overhead.
 *
 * This OpenHashSet is designed to serve as a building block for higher level data structures
 * such as an optimized hash map. Compared with standard hash set implementations, this class
 * exposes callback interfaces (e.g. allocateFunc, moveFunc) and interfaces to retrieve the
 * position of a key in the underlying array.
 *
 * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
 * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
 *
 * @param initialCapacity requested capacity, in `[1, 2^29]`; rounded up to a power of two
 * @param loadFactor      fill ratio in `(0.0, 1.0)` above which the table is doubled
 */
private[spark]
class OpenHashSet[@specialized(Long, Int) T: ClassTag](
    initialCapacity: Int,
    loadFactor: Double)
  extends Serializable {

  require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
  require(initialCapacity >= 1, "Invalid initial capacity")
  require(loadFactor < 1.0, "Load factor must be less than 1.0")
  require(loadFactor > 0.0, "Load factor must be greater than 0.0")

  import OpenHashSet._

  def this(initialCapacity: Int) = this(initialCapacity, 0.7)

  def this() = this(64)

  // The following member variables are declared as protected instead of private for the
  // specialization to work (specialized class extends the non-specialized one and needs access
  // to the "private" variables).

  protected val hasher: Hasher[T] = {
    // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
    // compiler has a bug when specialization is used together with this pattern matching, and
    // throws:
    // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
    //  found   : scala.reflect.AnyValManifest[Long]
    //  required: scala.reflect.ClassTag[Int]
    //         at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
    //         at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
    //         ...
    val mt = classTag[T]
    if (mt == ClassTag.Long) {
      (new LongHasher).asInstanceOf[Hasher[T]]
    } else if (mt == ClassTag.Int) {
      (new IntHasher).asInstanceOf[Hasher[T]]
    } else {
      new Hasher[T]
    }
  }

  protected var _capacity = nextPowerOf2(initialCapacity)
  protected var _mask = _capacity - 1
  protected var _size = 0

  // Bit i is set iff slot i of _data holds an element.
  protected var _bitset = new BitSet(_capacity)

  // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
  // specialization bug that would generate two arrays (one for Object and one for specialized T).
  protected var _data: Array[T] = _
  _data = new Array[T](_capacity)

  /** Number of elements in the set. */
  def size: Int = _size

  /** The capacity of the set (i.e. size of the underlying array). */
  def capacity: Int = _capacity

  /** Return true if this set contains the specified element. */
  def contains(k: T): Boolean = getPos(k) != INVALID_POS

  /**
   * Add an element to the set. If the set is over capacity after the insertion, grow the set
   * and rehash all elements.
   */
  def add(k: T): Unit = {
    addWithoutResize(k)
    rehashIfNeeded(k, grow, move)
  }

  /**
   * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
   * The caller is responsible for calling rehashIfNeeded.
   *
   * Use (retval & POSITION_MASK) to get the actual position, and
   * (retval & NONEXISTENCE_MASK) != 0 to detect that the key was newly inserted.
   *
   * @return The position where the key is placed, with the highest order bit set if the key
   *         did NOT exist previously (i.e. this call inserted it).
   */
  def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)

  /**
   * Rehash the set if it is overloaded.
   * @param k A parameter unused in the function, but to force the Scala compiler to specialize
   *          this method.
   * @param allocateFunc Callback invoked when we are allocating a new, larger array.
   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
   *                 to a new position (in the new data array).
   */
  def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit): Unit = {
    if (_size > loadFactor * _capacity) {
      rehash(k, allocateFunc, moveFunc)
    }
  }

  /**
   * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
   */
  def getPos(k: T): Int = {
    var pos = hashcode(hasher.hash(k)) & _mask
    var i = 1
    // Terminates because loadFactor < 1.0 keeps at least one slot empty after every rehash check.
    while (true) {
      if (!_bitset.get(pos)) {
        return INVALID_POS
      } else if (k == _data(pos)) {
        return pos
      } else {
        // Step sizes 1, 2, 3, ... give triangular-number offsets from the home slot.
        val delta = i
        pos = (pos + delta) & _mask
        i += 1
      }
    }
    // Never reached here
    INVALID_POS
  }

  /** Return the value at the specified position. */
  def getValue(pos: Int): T = _data(pos)

  /**
   * Return the next position with an element stored, starting from the given position inclusively.
   * Returns -1 when there is no such position.
   */
  def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)

  /**
   * Put an entry into the set. Return the position where the key is placed. In addition, the
   * highest bit in the returned position is set if the key did NOT exist prior to this put
   * (see NONEXISTENCE_MASK).
   *
   * This function assumes the data array has at least one empty slot.
   */
  private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
    val mask = data.length - 1
    var pos = hashcode(hasher.hash(k)) & mask
    var i = 1
    while (true) {
      if (!bitset.get(pos)) {
        // This is a new key.
        data(pos) = k
        bitset.set(pos)
        _size += 1
        return pos | NONEXISTENCE_MASK
      } else if (data(pos) == k) {
        // Found an existing key.
        return pos
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    // Never reached here
    assert(INVALID_POS != INVALID_POS)
    INVALID_POS
  }

  /**
   * Double the table's size and re-hash everything. We are not really using k, but it is declared
   * so Scala compiler can specialize this method (which leads to calling the specialized version
   * of putInto).
   *
   * @param k A parameter unused in the function, but to force the Scala compiler to specialize
   *          this method.
   * @param allocateFunc Callback invoked when we are allocating a new, larger array.
   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
   *                 to a new position (in the new data array).
   */
  private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit): Unit = {
    val newCapacity = _capacity * 2
    require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")

    allocateFunc(newCapacity)
    val newData = new Array[T](newCapacity)
    val newBitset = new BitSet(newCapacity)
    var pos = 0
    // putInto increments _size for every key it places, so start counting from zero again.
    _size = 0
    while (pos < _capacity) {
      if (_bitset.get(pos)) {
        val newPos = putInto(newBitset, newData, _data(pos))
        moveFunc(pos, newPos & POSITION_MASK)
      }
      pos += 1
    }
    _bitset = newBitset
    _data = newData
    _capacity = newCapacity
    _mask = newCapacity - 1
  }

  /**
   * Re-hash a value to deal better with hash functions that don't differ
   * in the lower bits, similar to java.util.HashMap
   */
  private def hashcode(h: Int): Int = {
    val r = h ^ (h >>> 20) ^ (h >>> 12)
    r ^ (r >>> 7) ^ (r >>> 4)
  }

  /** Smallest power of two that is >= n (n is already validated to be in [1, 2^29]). */
  private def nextPowerOf2(n: Int): Int = {
    val highBit = Integer.highestOneBit(n)
    if (highBit == n) n else highBit << 1
  }
}


private[spark]
object OpenHashSet {

  /** Returned by getPos when the key is not in the set. */
  val INVALID_POS = -1

  /** Flag (sign bit) OR-ed into addWithoutResize's result when the key was newly inserted. */
  val NONEXISTENCE_MASK = 0x80000000

  /**
   * Mask that strips NONEXISTENCE_MASK and keeps the slot index: every bit except the sign bit.
   * (The earlier value 0xEFFFFFF cleared bit 24 and bits 28-30 as well, corrupting positions in
   * tables with more than 2^24 slots.)
   */
  val POSITION_MASK = 0x7FFFFFFF

  /**
   * A set of specialized hash function implementation to avoid boxing hash code computation
   * in the specialized implementation of OpenHashSet.
   */
  sealed class Hasher[@specialized(Long, Int) T] {
    def hash(o: T): Int = o.hashCode()
  }

  class LongHasher extends Hasher[Long] {
    override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
  }

  class IntHasher extends Hasher[Int] {
    override def hash(o: Int): Int = o
  }

  // No-op callbacks used by OpenHashSet.add, which has no side arrays to grow or move.
  private def grow1(newSize: Int): Unit = {}
  private def move1(oldPos: Int, newPos: Int): Unit = {}

  private val grow: Int => Unit = grow1
  private val move: (Int, Int) => Unit = move1
}

// =================================================================================================
// core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
// =================================================================================================

/**
 * A fast hash map implementation for primitive, non-null keys. This hash map supports
 * insertions and updates, but not deletions. This map is about an order of magnitude
 * faster than java.util.HashMap, while using much less space overhead.
 *
 * Under the hood, it uses our OpenHashSet implementation.
 */
private[spark]
class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
                              @specialized(Long, Int, Double) V: ClassTag](
    initialCapacity: Int)
  extends Iterable[(K, V)]
  with Serializable {

  def this() = this(64)

  // Only Long and Int keys are supported.
  require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])

  // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
  // bug that would generate two arrays (one for Object and one for specialized T).
  protected var _keySet: OpenHashSet[K] = _
  private var _values: Array[V] = _
  _keySet = new OpenHashSet[K](initialCapacity)
  _values = new Array[V](_keySet.capacity)

  // Only populated while a rehash is in flight (between `grow` and the last `move`).
  private var _oldValues: Array[V] = null

  override def size: Int = _keySet.size

  /**
   * Get the value for a given key.
   *
   * No existence check is performed: for a key that is not in the map the position is
   * INVALID_POS (-1) and the array access throws. Use getOrElse when the key may be absent.
   */
  def apply(k: K): V = {
    val pos = _keySet.getPos(k)
    _values(pos)
  }

  /** Get the value for a given key, or returns elseValue if it doesn't exist. */
  def getOrElse(k: K, elseValue: V): V = {
    val pos = _keySet.getPos(k)
    if (pos >= 0) _values(pos) else elseValue
  }

  /** Set the value for a key */
  def update(k: K, v: V): Unit = {
    // Strip the "newly inserted" flag; we overwrite the slot either way.
    val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
    _values(pos) = v
    _keySet.rehashIfNeeded(k, grow, move)
    _oldValues = null
  }

  /**
   * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
   * set its value to mergeValue(oldValue).
   *
   * @return the newly updated value.
   */
  def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
    val pos = _keySet.addWithoutResize(k)
    if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
      // Newly inserted key: store the default, then let the key set grow if it is overloaded.
      val newValue = defaultValue
      _values(pos & OpenHashSet.POSITION_MASK) = newValue
      _keySet.rehashIfNeeded(k, grow, move)
      // Release the pre-rehash array (if any), exactly as update() does.
      _oldValues = null
      newValue
    } else {
      // Existing key: the flag bit is clear, so pos is already the plain slot index.
      _values(pos) = mergeValue(_values(pos))
      _values(pos)
    }
  }

  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
    var pos = 0
    var nextPair: (K, V) = computeNextPair()

    /** Get the next value we should return from next(), or null if we're finished iterating */
    def computeNextPair(): (K, V) = {
      pos = _keySet.nextPos(pos)
      if (pos >= 0) {
        val ret = (_keySet.getValue(pos), _values(pos))
        pos += 1
        ret
      } else {
        null
      }
    }

    def hasNext: Boolean = nextPair != null

    def next(): (K, V) = {
      val pair = nextPair
      nextPair = computeNextPair()
      pair
    }
  }

  // The following member variables are declared as protected instead of private for the
  // specialization to work (specialized class extends the unspecialized one and needs access
  // to the "private" variables).
  // They also should have been val's. We use var's because there is a Scala compiler bug that
  // would throw illegal access error at runtime if they are declared as val's.

  /** Rehash callback: stash the current values and allocate the larger value array. */
  protected var grow: Int => Unit = (newCapacity: Int) => {
    _oldValues = _values
    _values = new Array[V](newCapacity)
  }

  /** Rehash callback: carry one value from its old slot to the key's new slot. */
  protected var move: (Int, Int) => Unit = (oldPos: Int, newPos: Int) => {
    _values(newPos) = _oldValues(oldPos)
  }
}

// =================================================================================================
// core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
// =================================================================================================

/**
 * Provides a simple, non-threadsafe, array-backed vector that can store primitives.
 *
 * @param initialSize initial length of the backing array (may be 0; it grows on first append)
 */
private[spark]
class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
  private var numElements = 0
  private var array: Array[V] = _

  // NB: This must be separate from the declaration, otherwise the specialized parent class
  // will get its own array with the same initial size. TODO: Figure out why...
  array = new Array[V](initialSize)

  /** Return the element at `index`; fails the `require` if `index` is not below `length`. */
  def apply(index: Int): V = {
    require(index < numElements)
    array(index)
  }

  /** Append a value, doubling the backing array when it is full. */
  def +=(value: V): Unit = {
    if (numElements == array.length) {
      // max(1, ...) so that a vector created with initialSize = 0 can still grow.
      resize(math.max(1, array.length * 2))
    }
    array(numElements) = value
    numElements += 1
  }

  /** Number of elements appended so far (not the length of the backing array). */
  def length: Int = numElements

  /** The backing array itself; only the first `length` entries are meaningful. */
  def getUnderlyingArray: Array[V] = array

  /** Resizes the array, dropping elements if the total length decreases. */
  def resize(newLength: Int): Unit = {
    val newArray = new Array[V](newLength)
    array.copyToArray(newArray)
    array = newArray
    // Shrinking below the current element count really drops the tail elements.
    if (newLength < numElements) {
      numElements = newLength
    }
  }
}