diff options
-rw-r--r-- | src/library/scala/collection/immutable/HashMap.scala | 24 | ||||
-rw-r--r-- | src/library/scala/collection/immutable/HashSet.scala | 22 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/FlatHashTable.scala | 55 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/HashMap.scala | 10 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/HashSet.scala | 10 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/HashTable.scala | 86 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/LinkedHashMap.scala | 28 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/LinkedHashSet.scala | 13 | ||||
-rw-r--r-- | test/files/jvm/serialization.check | 6 | ||||
-rw-r--r-- | test/files/jvm/t1600.scala | 76 |
10 files changed, 280 insertions, 50 deletions
diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala index a036a9abfb..7825d62527 100644 --- a/src/library/scala/collection/immutable/HashMap.scala +++ b/src/library/scala/collection/immutable/HashMap.scala @@ -33,15 +33,15 @@ import annotation.unchecked.uncheckedVariance * @version 2.0, 19/01/2007 * @since 2.3 */ -@serializable @SerialVersionUID(8886909077084990906L) +@serializable @SerialVersionUID(1L) class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mutable.HashTable[A] { type Entry = scala.collection.mutable.DefaultEntry[A, Any] - protected var later: HashMap[A, B @uncheckedVariance] = null - protected var oldKey: A = _ - protected var oldValue: Option[B @uncheckedVariance] = _ - protected var deltaSize: Int = _ + @transient protected var later: HashMap[A, B @uncheckedVariance] = null + @transient protected var oldKey: A = _ + @transient protected var oldValue: Option[B @uncheckedVariance] = _ + @transient protected var deltaSize: Int = _ override def empty = HashMap.empty[A, B] @@ -125,10 +125,12 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut private def logLimit: Int = math.sqrt(table.length).toInt private[this] def markUpdated(key: A, ov: Option[B], delta: Int) { - val lv = loadFactor + val lf = loadFactor later = new HashMap[A, B] { override def initialSize = 0 - override def loadFactor = lv + /* We need to do this to avoid a reference to the outer HashMap */ + _loadFactor = lf + override def loadFactor = _loadFactor table = HashMap.this.table tableSize = HashMap.this.tableSize threshold = HashMap.this.threshold @@ -174,6 +176,14 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut while (m.later != null) m = m.later if (m ne this) makeCopy(m) } + + private def writeObject(out: java.io.ObjectOutputStream) { + serializeTo(out, _.value) + } + + private def readObject(in: java.io.ObjectInputStream) { + init[B](in, new Entry(_, _)) + } } /** A factory object for immutable HashMaps. diff --git a/src/library/scala/collection/immutable/HashSet.scala b/src/library/scala/collection/immutable/HashSet.scala index e55469d173..0ad6e156a2 100644 --- a/src/library/scala/collection/immutable/HashSet.scala +++ b/src/library/scala/collection/immutable/HashSet.scala @@ -32,16 +32,16 @@ import generic._ * @version 2.8 * @since 2.3 */ -@serializable @SerialVersionUID(4020728942921483037L) +@serializable @SerialVersionUID(1L) class HashSet[A] extends Set[A] with GenericSetTemplate[A, HashSet] with SetLike[A, HashSet[A]] with mutable.FlatHashTable[A] { override def companion: GenericCompanion[HashSet] = HashSet - protected var later: HashSet[A] = null - protected var changedElem: A = _ - protected var deleted: Boolean = _ + @transient protected var later: HashSet[A] = null + @transient protected var changedElem: A = _ + @transient protected var deleted: Boolean = _ def contains(elem: A): Boolean = synchronized { var m = this @@ -99,10 +99,12 @@ class HashSet[A] extends Set[A] private def logLimit: Int = math.sqrt(table.length).toInt private def markUpdated(elem: A, del: Boolean) { - val lv = loadFactor + val lf = loadFactor later = new HashSet[A] { override def initialSize = 0 - override def loadFactor = lv + /* We need to do this to avoid a reference to the outer HashMap */ + _loadFactor = lf + override def loadFactor = _loadFactor table = HashSet.this.table tableSize = HashSet.this.tableSize threshold = HashSet.this.threshold @@ -132,6 +134,14 @@ class HashSet[A] extends Set[A] while (m.later != null) m = m.later if (m ne this) makeCopy(m) } + + private def writeObject(s: java.io.ObjectOutputStream) { + serializeTo(s) + } + + private def readObject(in: java.io.ObjectInputStream) { + init(in, x => x) + } } /** A factory object for immutable HashSets. diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala index 1d55933050..d06ead7888 100644 --- a/src/library/scala/collection/mutable/FlatHashTable.scala +++ b/src/library/scala/collection/mutable/FlatHashTable.scala @@ -27,18 +27,63 @@ trait FlatHashTable[A] { private final val tableDebug = false + @transient private[collection] var _loadFactor = loadFactor + /** The actual hash table. */ - protected var table: Array[AnyRef] = - if (initialSize == 0) null else new Array(initialSize) + @transient protected var table: Array[AnyRef] = new Array(initialCapacity) /** The number of mappings contained in this hash table. */ - protected var tableSize = 0 + @transient protected var tableSize = 0 /** The next size value at which to resize (capacity * load factor). */ - protected var threshold: Int = newThreshold(initialSize) + @transient protected var threshold: Int = newThreshold(initialCapacity) + + import HashTable.powerOfTwo + private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize) + private def initialCapacity = capacity(initialSize) + + /** + * Initialises the collection from the input stream. `f` will be called for each element + * read from the input stream in the order determined by the stream. This is useful for + * structures where iteration order is important (e.g. LinkedHashSet). + * + * The serialization format expected is the one produced by `serializeTo`. + */ + private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) { + in.defaultReadObject + + _loadFactor = in.readInt + assert(_loadFactor > 0) + + val size = in.readInt + assert(size >= 0) + + table = new Array(capacity(size * loadFactorDenum / _loadFactor)) + threshold = newThreshold(table.size) + + var index = 0 + while (index < size) { + val elem = in.readObject.asInstanceOf[A] + f(elem) + addEntry(elem) + index += 1 + } + } + + /** + * Serializes the collection to the output stream by saving the load factor, collection + * size and collection elements. `foreach` determines the order in which the elements are saved + * to the stream. To deserialize, `init` should be used. + */ + private[collection] def serializeTo(out: java.io.ObjectOutputStream) { + out.defaultWriteObject + out.writeInt(_loadFactor) + out.writeInt(tableSize) + iterator.foreach(out.writeObject) + } def findEntry(elem: A): Option[A] = { var h = index(elemHashCode(elem)) @@ -154,7 +199,7 @@ trait FlatHashTable[A] { protected final def index(hcode: Int) = improve(hcode) & (table.length - 1) private def newThreshold(size: Int) = { - val lf = loadFactor + val lf = _loadFactor assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5") (size.toLong * lf / loadFactorDenum ).toInt } diff --git a/src/library/scala/collection/mutable/HashMap.scala b/src/library/scala/collection/mutable/HashMap.scala index e4969a3af0..21d524129b 100644 --- a/src/library/scala/collection/mutable/HashMap.scala +++ b/src/library/scala/collection/mutable/HashMap.scala @@ -17,7 +17,7 @@ import generic._ /** * @since 1 */ -@serializable @SerialVersionUID(-8682987922734091219L) +@serializable @SerialVersionUID(1L) class HashMap[A, B] extends Map[A, B] with MapLike[A, B, HashMap[A, B]] with HashTable[A] { @@ -84,6 +84,14 @@ class HashMap[A, B] extends Map[A, B] def hasNext = iter.hasNext def next = iter.next.value } + + private def writeObject(out: java.io.ObjectOutputStream) { + serializeTo(out, _.value) + } + + private def readObject(in: java.io.ObjectInputStream) { + init[B](in, new Entry(_, _)) + } } /** This class implements mutable maps using a hashtable. diff --git a/src/library/scala/collection/mutable/HashSet.scala b/src/library/scala/collection/mutable/HashSet.scala index 9144a4be88..0630f96fa2 100644 --- a/src/library/scala/collection/mutable/HashSet.scala +++ b/src/library/scala/collection/mutable/HashSet.scala @@ -21,7 +21,7 @@ import generic._ * @version 2.0, 31/12/2006 * @since 1 */ -@serializable +@serializable @SerialVersionUID(1L) class HashSet[A] extends Set[A] with GenericSetTemplate[A, HashSet] with SetLike[A, HashSet[A]] @@ -51,6 +51,14 @@ class HashSet[A] extends Set[A] } override def clone(): Set[A] = new HashSet[A] ++= this + + private def writeObject(s: java.io.ObjectOutputStream) { + serializeTo(s) + } + + private def readObject(in: java.io.ObjectInputStream) { + init(in, x => x) + } } /** Factory object for `HashSet` class */ diff --git a/src/library/scala/collection/mutable/HashTable.scala b/src/library/scala/collection/mutable/HashTable.scala index fa3865ada2..cdb6d6904b 100644 --- a/src/library/scala/collection/mutable/HashTable.scala +++ b/src/library/scala/collection/mutable/HashTable.scala @@ -31,6 +31,7 @@ package mutable * @since 1 */ trait HashTable[A] { + import HashTable._ protected type Entry >: Null <: HashEntry[A, Entry] @@ -47,33 +48,65 @@ trait HashTable[A] { */ protected def initialThreshold: Int = newThreshold(initialCapacity) + @transient private[collection] var _loadFactor = loadFactor + /** The actual hash table. */ - protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity) + @transient protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity) + + /** The number of mappings contained in this hash table. + */ + @transient protected var tableSize: Int = 0 - private def initialCapacity = if (initialSize == 0) 1 else powerOfTwo(initialSize) + /** The next size value at which to resize (capacity * load factor). + */ + @transient protected var threshold: Int = initialThreshold + + private def initialCapacity = capacity(initialSize) /** - * Returns a power of two >= `target`. + * Initialises the collection from the input stream. `f` will be called for each key/value pair + * read from the input stream in the order determined by the stream. This is useful for + * structures where iteration order is important (e.g. LinkedHashMap). */ - private def powerOfTwo(target: Int): Int = { - /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */ - var c = target - 1; - c |= c >>> 1; - c |= c >>> 2; - c |= c >>> 4; - c |= c >>> 8; - c |= c >>> 16; - c + 1; + private[collection] def init[B](in: java.io.ObjectInputStream, f: (A, B) => Entry) { + in.defaultReadObject + + _loadFactor = in.readInt + assert(_loadFactor > 0) + + val size = in.readInt + assert(size >= 0) + + table = new Array(capacity(size * loadFactorDenum / _loadFactor)) + threshold = newThreshold(table.size) + + var index = 0 + while (index < size) { + addEntry(f(in.readObject.asInstanceOf[A], in.readObject.asInstanceOf[B])) + index += 1 + } } - /** The number of mappings contained in this hash table. + /** + * Serializes the collection to the output stream by saving the load factor, collection + * size, collection keys and collection values. `value` is responsible for providing a value + * from an entry. + * + * `foreach` determines the order in which the key/value pairs are saved to the stream. To + * deserialize, `init` should be used. */ - protected var tableSize: Int = 0 + private[collection] def serializeTo[B](out: java.io.ObjectOutputStream, value: Entry => B) { + out.defaultWriteObject + out.writeInt(loadFactor) + out.writeInt(tableSize) + foreachEntry { entry => + out.writeObject(entry.key) + out.writeObject(value(entry)) + } + } - /** The next size value at which to resize (capacity * load factor). - */ - protected var threshold: Int = initialThreshold + private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize) /** Find entry with given key in table, null if not found */ @@ -172,9 +205,9 @@ trait HashTable[A] { } private def newThreshold(size: Int) = - ((size.toLong * loadFactor)/loadFactorDenum).toInt + ((size.toLong * _loadFactor)/loadFactorDenum).toInt - private def resize(newSize: Int) = { + private def resize(newSize: Int) { val oldTable = table table = new Array(newSize) var i = oldTable.length - 1 @@ -206,4 +239,19 @@ trait HashTable[A] { protected final def index(hcode: Int) = improve(hcode) & (table.length - 1) } +private[collection] object HashTable { + /** + * Returns a power of two >= `target`. + */ + private[collection] def powerOfTwo(target: Int): Int = { + /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */ + var c = target - 1; + c |= c >>> 1; + c |= c >>> 2; + c |= c >>> 4; + c |= c >>> 8; + c |= c >>> 16; + c + 1; + } +} diff --git a/src/library/scala/collection/mutable/LinkedHashMap.scala b/src/library/scala/collection/mutable/LinkedHashMap.scala index 308db1a4d4..183a3aedb0 100644 --- a/src/library/scala/collection/mutable/LinkedHashMap.scala +++ b/src/library/scala/collection/mutable/LinkedHashMap.scala @@ -29,7 +29,7 @@ object LinkedHashMap extends MutableMapFactory[LinkedHashMap] { /** * @since 2.7 */ -@serializable +@serializable @SerialVersionUID(1L) class LinkedHashMap[A, B] extends Map[A, B] with MapLike[A, B, LinkedHashMap[A, B]] with HashTable[A] { @@ -39,8 +39,8 @@ class LinkedHashMap[A, B] extends Map[A, B] type Entry = LinkedEntry[A, B] - protected var firstEntry: Entry = null - protected var lastEntry: Entry = null + @transient protected var firstEntry: Entry = null + @transient protected var lastEntry: Entry = null def get(key: A): Option[B] = { val e = findEntry(key) @@ -53,9 +53,7 @@ class LinkedHashMap[A, B] extends Map[A, B] if (e == null) { val e = new Entry(key, value) addEntry(e) - if (firstEntry == null) firstEntry = e - else { lastEntry.later = e; e.earlier = lastEntry } - lastEntry = e + updateLinkedEntries(e) None } else { val v = e.value @@ -64,6 +62,12 @@ class LinkedHashMap[A, B] extends Map[A, B] } } + private def updateLinkedEntries(e: Entry) { + if (firstEntry == null) firstEntry = e + else { lastEntry.later = e; e.earlier = lastEntry } + lastEntry = e + } + override def remove(key: A): Option[B] = { val e = removeEntry(key) if (e eq null) None @@ -115,4 +119,16 @@ class LinkedHashMap[A, B] extends Map[A, B] clearTable() firstEntry = null } + + private def writeObject(out: java.io.ObjectOutputStream) { + serializeTo(out, _.value) + } + + private def readObject(in: java.io.ObjectInputStream) { + init[B](in, { (key, value) => + val entry = new Entry(key, value) + updateLinkedEntries(entry) + entry + }) + } } diff --git a/src/library/scala/collection/mutable/LinkedHashSet.scala b/src/library/scala/collection/mutable/LinkedHashSet.scala index 081b068723..b9313bd1d4 100644 --- a/src/library/scala/collection/mutable/LinkedHashSet.scala +++ b/src/library/scala/collection/mutable/LinkedHashSet.scala @@ -17,7 +17,7 @@ import generic._ * Should be rewritten to be more efficient. * @since 2.2 */ -@serializable +@serializable @SerialVersionUID(1L) class LinkedHashSet[A] extends Set[A] with GenericSetTemplate[A, LinkedHashSet] with SetLike[A, LinkedHashSet[A]] @@ -25,7 +25,7 @@ class LinkedHashSet[A] extends Set[A] { override def companion: GenericCompanion[LinkedHashSet] = LinkedHashSet - protected val ordered = new ListBuffer[A] + @transient private var ordered = new ListBuffer[A] override def size = tableSize @@ -52,6 +52,15 @@ class LinkedHashSet[A] extends Set[A] override def iterator = ordered.iterator override def foreach[U](f: A => U) = ordered foreach f + + private def writeObject(s: java.io.ObjectOutputStream) { + serializeTo(s) + } + + private def readObject(in: java.io.ObjectInputStream) { + ordered = new ListBuffer[A] + init(in, ordered += ) + } } /** Factory object for `LinkedHashSet` class */ diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check index f1b5b10ec6..2b0ad3888b 100644 --- a/test/files/jvm/serialization.check +++ b/test/files/jvm/serialization.check @@ -78,11 +78,11 @@ y = BitSet(2, 3) x equals y: true, y equals x: true x = Map(2 -> B, 1 -> A, 3 -> C) -y = Map(2 -> B, 1 -> A, 3 -> C) +y = Map(1 -> A, 2 -> B, 3 -> C) x equals y: true, y equals x: true x = Set(1, 2) -y = Set(1, 2) +y = Set(2, 1) x equals y: true, y equals x: true x = List((buffers,20), (layers,2), (title,3)) @@ -158,7 +158,7 @@ y = BitSet(0, 8, 9) x equals y: true, y equals x: true x = Map(A -> 1, C -> 3, B -> 2) -y = Map(A -> 1, C -> 3, B -> 2) +y = Map(B -> 2, C -> 3, A -> 1) x equals y: true, y equals x: true x = Set(layers, buffers, title) diff --git a/test/files/jvm/t1600.scala b/test/files/jvm/t1600.scala new file mode 100644 index 0000000000..1cdcee8547 --- /dev/null +++ b/test/files/jvm/t1600.scala @@ -0,0 +1,76 @@ + +/** + * Checks that serialization of hash-based collections works correctly if the hashCode + * changes on deserialization. + */ +object Test { + + import collection._ + def main(args: Array[String]) { + for (i <- Seq(0, 1, 2, 10, 100)) { + def entries = (0 until i).map(i => (new Foo, i)).toList + def elements = entries.map(_._1) + + val maps = Seq[Map[Foo, Int]](new mutable.HashMap, new mutable.LinkedHashMap, + new immutable.HashMap).map(_ ++ entries) + test[Map[Foo, Int]](maps, entries.size, assertMap _) + + val sets = Seq[Set[Foo]](new mutable.HashSet, new mutable.LinkedHashSet, + new immutable.HashSet).map(_ ++ elements) + test[Set[Foo]](sets, entries.size, assertSet _) + } + } + + private def test[A <: AnyRef](collections: Seq[A], expectedSize: Int, assertFunction: (A, Int) => Unit) { + for (collection <- collections) { + assertFunction(collection, expectedSize) + + val bytes = toBytes(collection) + Foo.hashCodeModifier = 1 + val deserializedCollection = toObject[A](bytes) + + assertFunction(deserializedCollection, expectedSize) + assert(deserializedCollection.getClass == collection.getClass, + "collection class should remain the same after deserialization") + Foo.hashCodeModifier = 0 + } + } + + private def toObject[A](bytes: Array[Byte]): A = { + val in = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(bytes)) + in.readObject.asInstanceOf[A] + } + + private def toBytes(o: AnyRef): Array[Byte] = { + val bos = new java.io.ByteArrayOutputStream + val out = new java.io.ObjectOutputStream(bos) + out.writeObject(o) + out.close + bos.toByteArray + } + + private def assertMap[A, B](map: Map[A, B], expectedSize: Int) { + assert(expectedSize == map.size, "expected map size: " + expectedSize + ", actual size: " + map.size) + map.foreach { case (k, v) => + assert(map.contains(k), "contains should return true for key in the map, key: " + k) + assert(map(k) == v) + } + } + + private def assertSet[A](set: Set[A], expectedSize: Int) { + assert(expectedSize == set.size, "expected set size: " + expectedSize + ", actual size: " + set.size) + set.foreach { e => assert(set.contains(e), "contains should return true for element in the set, element: " + e) } + } + + object Foo { + /* Used to simulate a hashCode change caused by deserializing an instance with an + * identity-based hashCode in another JVM. + */ + var hashCodeModifier = 0 + } + + @serializable + class Foo { + override def hashCode = System.identityHashCode(this) + Foo.hashCodeModifier + } +} |