From a3bf3f136caaefa98268607a3529b7554df5fc80 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Tue, 1 Dec 2009 18:28:55 +0000 Subject: [This patch submitted by ismael juma - commit m... [This patch submitted by ismael juma - commit message his words, but condensed.] Fix ticket #1600: Serialization and deserialization of hash-based collections should not re-use hashCode. The collection is rebuilt on deserialization - note that this is not compatible with the previous serialization format. All @SerialVersionUIDs have been reset to 1. WeakHashMap is not Serializable and should not be so. TreeHashMap has not been reintegrated yet. OpenHashMap has not been updated. (I think this collection is flawed and should be removed or reimplemented.) --- .../scala/collection/immutable/HashMap.scala | 24 ++++-- .../scala/collection/immutable/HashSet.scala | 22 ++++-- .../scala/collection/mutable/FlatHashTable.scala | 55 ++++++++++++-- src/library/scala/collection/mutable/HashMap.scala | 10 ++- src/library/scala/collection/mutable/HashSet.scala | 10 ++- .../scala/collection/mutable/HashTable.scala | 86 +++++++++++++++++----- .../scala/collection/mutable/LinkedHashMap.scala | 28 +++++-- .../scala/collection/mutable/LinkedHashSet.scala | 13 +++- test/files/jvm/serialization.check | 6 +- test/files/jvm/t1600.scala | 76 +++++++++++++++++++ 10 files changed, 280 insertions(+), 50 deletions(-) create mode 100644 test/files/jvm/t1600.scala diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala index a036a9abfb..7825d62527 100644 --- a/src/library/scala/collection/immutable/HashMap.scala +++ b/src/library/scala/collection/immutable/HashMap.scala @@ -33,15 +33,15 @@ import annotation.unchecked.uncheckedVariance * @version 2.0, 19/01/2007 * @since 2.3 */ -@serializable @SerialVersionUID(8886909077084990906L) +@serializable @SerialVersionUID(1L) class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mutable.HashTable[A] { type Entry = scala.collection.mutable.DefaultEntry[A, Any] - protected var later: HashMap[A, B @uncheckedVariance] = null - protected var oldKey: A = _ - protected var oldValue: Option[B @uncheckedVariance] = _ - protected var deltaSize: Int = _ + @transient protected var later: HashMap[A, B @uncheckedVariance] = null + @transient protected var oldKey: A = _ + @transient protected var oldValue: Option[B @uncheckedVariance] = _ + @transient protected var deltaSize: Int = _ override def empty = HashMap.empty[A, B] @@ -125,10 +125,12 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut private def logLimit: Int = math.sqrt(table.length).toInt private[this] def markUpdated(key: A, ov: Option[B], delta: Int) { - val lv = loadFactor + val lf = loadFactor later = new HashMap[A, B] { override def initialSize = 0 - override def loadFactor = lv + /* We need to do this to avoid a reference to the outer HashMap */ + _loadFactor = lf + override def loadFactor = _loadFactor table = HashMap.this.table tableSize = HashMap.this.tableSize threshold = HashMap.this.threshold @@ -174,6 +176,14 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut while (m.later != null) m = m.later if (m ne this) makeCopy(m) } + + private def writeObject(out: java.io.ObjectOutputStream) { + serializeTo(out, _.value) + } + + private def readObject(in: java.io.ObjectInputStream) { + init[B](in, new Entry(_, _)) + } } /** A factory object for immutable HashMaps. diff --git a/src/library/scala/collection/immutable/HashSet.scala b/src/library/scala/collection/immutable/HashSet.scala index e55469d173..0ad6e156a2 100644 --- a/src/library/scala/collection/immutable/HashSet.scala +++ b/src/library/scala/collection/immutable/HashSet.scala @@ -32,16 +32,16 @@ import generic._ * @version 2.8 * @since 2.3 */ -@serializable @SerialVersionUID(4020728942921483037L) +@serializable @SerialVersionUID(1L) class HashSet[A] extends Set[A] with GenericSetTemplate[A, HashSet] with SetLike[A, HashSet[A]] with mutable.FlatHashTable[A] { override def companion: GenericCompanion[HashSet] = HashSet - protected var later: HashSet[A] = null - protected var changedElem: A = _ - protected var deleted: Boolean = _ + @transient protected var later: HashSet[A] = null + @transient protected var changedElem: A = _ + @transient protected var deleted: Boolean = _ def contains(elem: A): Boolean = synchronized { var m = this @@ -99,10 +99,12 @@ class HashSet[A] extends Set[A] private def logLimit: Int = math.sqrt(table.length).toInt private def markUpdated(elem: A, del: Boolean) { - val lv = loadFactor + val lf = loadFactor later = new HashSet[A] { override def initialSize = 0 - override def loadFactor = lv + /* We need to do this to avoid a reference to the outer HashMap */ + _loadFactor = lf + override def loadFactor = _loadFactor table = HashSet.this.table tableSize = HashSet.this.tableSize threshold = HashSet.this.threshold @@ -132,6 +134,14 @@ class HashSet[A] extends Set[A] while (m.later != null) m = m.later if (m ne this) makeCopy(m) } + + private def writeObject(s: java.io.ObjectOutputStream) { + serializeTo(s) + } + + private def readObject(in: java.io.ObjectInputStream) { + init(in, x => x) + } } /** A factory object for immutable HashSets. diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala index 1d55933050..d06ead7888 100644 --- a/src/library/scala/collection/mutable/FlatHashTable.scala +++ b/src/library/scala/collection/mutable/FlatHashTable.scala @@ -27,18 +27,63 @@ trait FlatHashTable[A] { private final val tableDebug = false + @transient private[collection] var _loadFactor = loadFactor + /** The actual hash table. */ - protected var table: Array[AnyRef] = - if (initialSize == 0) null else new Array(initialSize) + @transient protected var table: Array[AnyRef] = new Array(initialCapacity) /** The number of mappings contained in this hash table. */ - protected var tableSize = 0 + @transient protected var tableSize = 0 /** The next size value at which to resize (capacity * load factor). */ - protected var threshold: Int = newThreshold(initialSize) + @transient protected var threshold: Int = newThreshold(initialCapacity) + + import HashTable.powerOfTwo + private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize) + private def initialCapacity = capacity(initialSize) + + /** + * Initialises the collection from the input stream. `f` will be called for each element + * read from the input stream in the order determined by the stream. This is useful for + * structures where iteration order is important (e.g. LinkedHashSet). + * + * The serialization format expected is the one produced by `serializeTo`. + */ + private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) { + in.defaultReadObject + + _loadFactor = in.readInt + assert(_loadFactor > 0) + + val size = in.readInt + assert(size >= 0) + + table = new Array(capacity(size * loadFactorDenum / _loadFactor)) + threshold = newThreshold(table.size) + + var index = 0 + while (index < size) { + val elem = in.readObject.asInstanceOf[A] + f(elem) + addEntry(elem) + index += 1 + } + } + + /** + * Serializes the collection to the output stream by saving the load factor, collection + * size and collection elements. `foreach` determines the order in which the elements are saved + * to the stream. To deserialize, `init` should be used. + */ + private[collection] def serializeTo(out: java.io.ObjectOutputStream) { + out.defaultWriteObject + out.writeInt(_loadFactor) + out.writeInt(tableSize) + iterator.foreach(out.writeObject) + } def findEntry(elem: A): Option[A] = { var h = index(elemHashCode(elem)) @@ -154,7 +199,7 @@ trait FlatHashTable[A] { protected final def index(hcode: Int) = improve(hcode) & (table.length - 1) private def newThreshold(size: Int) = { - val lf = loadFactor + val lf = _loadFactor assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5") (size.toLong * lf / loadFactorDenum ).toInt } diff --git a/src/library/scala/collection/mutable/HashMap.scala b/src/library/scala/collection/mutable/HashMap.scala index e4969a3af0..21d524129b 100644 --- a/src/library/scala/collection/mutable/HashMap.scala +++ b/src/library/scala/collection/mutable/HashMap.scala @@ -17,7 +17,7 @@ import generic._ /** * @since 1 */ -@serializable @SerialVersionUID(-8682987922734091219L) +@serializable @SerialVersionUID(1L) class HashMap[A, B] extends Map[A, B] with MapLike[A, B, HashMap[A, B]] with HashTable[A] { @@ -84,6 +84,14 @@ class HashMap[A, B] extends Map[A, B] def hasNext = iter.hasNext def next = iter.next.value } + + private def writeObject(out: java.io.ObjectOutputStream) { + serializeTo(out, _.value) + } + + private def readObject(in: java.io.ObjectInputStream) { + init[B](in, new Entry(_, _)) + } } /** This class implements mutable maps using a hashtable. diff --git a/src/library/scala/collection/mutable/HashSet.scala b/src/library/scala/collection/mutable/HashSet.scala index 9144a4be88..0630f96fa2 100644 --- a/src/library/scala/collection/mutable/HashSet.scala +++ b/src/library/scala/collection/mutable/HashSet.scala @@ -21,7 +21,7 @@ import generic._ * @version 2.0, 31/12/2006 * @since 1 */ -@serializable +@serializable @SerialVersionUID(1L) class HashSet[A] extends Set[A] with GenericSetTemplate[A, HashSet] with SetLike[A, HashSet[A]] @@ -51,6 +51,14 @@ class HashSet[A] extends Set[A] } override def clone(): Set[A] = new HashSet[A] ++= this + + private def writeObject(s: java.io.ObjectOutputStream) { + serializeTo(s) + } + + private def readObject(in: java.io.ObjectInputStream) { + init(in, x => x) + } } /** Factory object for `HashSet` class */ diff --git a/src/library/scala/collection/mutable/HashTable.scala b/src/library/scala/collection/mutable/HashTable.scala index fa3865ada2..cdb6d6904b 100644 --- a/src/library/scala/collection/mutable/HashTable.scala +++ b/src/library/scala/collection/mutable/HashTable.scala @@ -31,6 +31,7 @@ package mutable * @since 1 */ trait HashTable[A] { + import HashTable._ protected type Entry >: Null <: HashEntry[A, Entry] @@ -47,33 +48,65 @@ trait HashTable[A] { */ protected def initialThreshold: Int = newThreshold(initialCapacity) + @transient private[collection] var _loadFactor = loadFactor + /** The actual hash table. */ - protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity) + @transient protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity) + + /** The number of mappings contained in this hash table. + */ + @transient protected var tableSize: Int = 0 - private def initialCapacity = if (initialSize == 0) 1 else powerOfTwo(initialSize) + /** The next size value at which to resize (capacity * load factor). + */ + @transient protected var threshold: Int = initialThreshold + + private def initialCapacity = capacity(initialSize) /** - * Returns a power of two >= `target`. + * Initialises the collection from the input stream. `f` will be called for each key/value pair + * read from the input stream in the order determined by the stream. This is useful for + * structures where iteration order is important (e.g. LinkedHashMap). */ - private def powerOfTwo(target: Int): Int = { - /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */ - var c = target - 1; - c |= c >>> 1; - c |= c >>> 2; - c |= c >>> 4; - c |= c >>> 8; - c |= c >>> 16; - c + 1; + private[collection] def init[B](in: java.io.ObjectInputStream, f: (A, B) => Entry) { + in.defaultReadObject + + _loadFactor = in.readInt + assert(_loadFactor > 0) + + val size = in.readInt + assert(size >= 0) + + table = new Array(capacity(size * loadFactorDenum / _loadFactor)) + threshold = newThreshold(table.size) + + var index = 0 + while (index < size) { + addEntry(f(in.readObject.asInstanceOf[A], in.readObject.asInstanceOf[B])) + index += 1 + } } - /** The number of mappings contained in this hash table. + /** + * Serializes the collection to the output stream by saving the load factor, collection + * size, collection keys and collection values. `value` is responsible for providing a value + * from an entry. + * + * `foreach` determines the order in which the key/value pairs are saved to the stream. To + * deserialize, `init` should be used. */ - protected var tableSize: Int = 0 + private[collection] def serializeTo[B](out: java.io.ObjectOutputStream, value: Entry => B) { + out.defaultWriteObject + out.writeInt(loadFactor) + out.writeInt(tableSize) + foreachEntry { entry => + out.writeObject(entry.key) + out.writeObject(value(entry)) + } + } - /** The next size value at which to resize (capacity * load factor). - */ - protected var threshold: Int = initialThreshold + private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize) /** Find entry with given key in table, null if not found */ @@ -172,9 +205,9 @@ trait HashTable[A] { } private def newThreshold(size: Int) = - ((size.toLong * loadFactor)/loadFactorDenum).toInt + ((size.toLong * _loadFactor)/loadFactorDenum).toInt - private def resize(newSize: Int) = { + private def resize(newSize: Int) { val oldTable = table table = new Array(newSize) var i = oldTable.length - 1 @@ -206,4 +239,19 @@ trait HashTable[A] { protected final def index(hcode: Int) = improve(hcode) & (table.length - 1) } +private[collection] object HashTable { + /** + * Returns a power of two >= `target`. + */ + private[collection] def powerOfTwo(target: Int): Int = { + /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */ + var c = target - 1; + c |= c >>> 1; + c |= c >>> 2; + c |= c >>> 4; + c |= c >>> 8; + c |= c >>> 16; + c + 1; + } +} diff --git a/src/library/scala/collection/mutable/LinkedHashMap.scala b/src/library/scala/collection/mutable/LinkedHashMap.scala index 308db1a4d4..183a3aedb0 100644 --- a/src/library/scala/collection/mutable/LinkedHashMap.scala +++ b/src/library/scala/collection/mutable/LinkedHashMap.scala @@ -29,7 +29,7 @@ object LinkedHashMap extends MutableMapFactory[LinkedHashMap] { /** * @since 2.7 */ -@serializable +@serializable @SerialVersionUID(1L) class LinkedHashMap[A, B] extends Map[A, B] with MapLike[A, B, LinkedHashMap[A, B]] with HashTable[A] { @@ -39,8 +39,8 @@ class LinkedHashMap[A, B] extends Map[A, B] type Entry = LinkedEntry[A, B] - protected var firstEntry: Entry = null - protected var lastEntry: Entry = null + @transient protected var firstEntry: Entry = null + @transient protected var lastEntry: Entry = null def get(key: A): Option[B] = { val e = findEntry(key) @@ -53,9 +53,7 @@ class LinkedHashMap[A, B] extends Map[A, B] if (e == null) { val e = new Entry(key, value) addEntry(e) - if (firstEntry == null) firstEntry = e - else { lastEntry.later = e; e.earlier = lastEntry } - lastEntry = e + updateLinkedEntries(e) None } else { val v = e.value @@ -64,6 +62,12 @@ class LinkedHashMap[A, B] extends Map[A, B] } } + private def updateLinkedEntries(e: Entry) { + if (firstEntry == null) firstEntry = e + else { lastEntry.later = e; e.earlier = lastEntry } + lastEntry = e + } + override def remove(key: A): Option[B] = { val e = removeEntry(key) if (e eq null) None @@ -115,4 +119,16 @@ class LinkedHashMap[A, B] extends Map[A, B] clearTable() firstEntry = null } + + private def writeObject(out: java.io.ObjectOutputStream) { + serializeTo(out, _.value) + } + + private def readObject(in: java.io.ObjectInputStream) { + init[B](in, { (key, value) => + val entry = new Entry(key, value) + updateLinkedEntries(entry) + entry + }) + } } diff --git a/src/library/scala/collection/mutable/LinkedHashSet.scala b/src/library/scala/collection/mutable/LinkedHashSet.scala index 081b068723..b9313bd1d4 100644 --- a/src/library/scala/collection/mutable/LinkedHashSet.scala +++ b/src/library/scala/collection/mutable/LinkedHashSet.scala @@ -17,7 +17,7 @@ import generic._ * Should be rewritten to be more efficient. * @since 2.2 */ -@serializable +@serializable @SerialVersionUID(1L) class LinkedHashSet[A] extends Set[A] with GenericSetTemplate[A, LinkedHashSet] with SetLike[A, LinkedHashSet[A]] @@ -25,7 +25,7 @@ class LinkedHashSet[A] extends Set[A] { override def companion: GenericCompanion[LinkedHashSet] = LinkedHashSet - protected val ordered = new ListBuffer[A] + @transient private var ordered = new ListBuffer[A] override def size = tableSize @@ -52,6 +52,15 @@ class LinkedHashSet[A] extends Set[A] override def iterator = ordered.iterator override def foreach[U](f: A => U) = ordered foreach f + + private def writeObject(s: java.io.ObjectOutputStream) { + serializeTo(s) + } + + private def readObject(in: java.io.ObjectInputStream) { + ordered = new ListBuffer[A] + init(in, ordered += ) + } } /** Factory object for `LinkedHashSet` class */ diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check index f1b5b10ec6..2b0ad3888b 100644 --- a/test/files/jvm/serialization.check +++ b/test/files/jvm/serialization.check @@ -78,11 +78,11 @@ y = BitSet(2, 3) x equals y: true, y equals x: true x = Map(2 -> B, 1 -> A, 3 -> C) -y = Map(2 -> B, 1 -> A, 3 -> C) +y = Map(1 -> A, 2 -> B, 3 -> C) x equals y: true, y equals x: true x = Set(1, 2) -y = Set(1, 2) +y = Set(2, 1) x equals y: true, y equals x: true x = List((buffers,20), (layers,2), (title,3)) @@ -158,7 +158,7 @@ y = BitSet(0, 8, 9) x equals y: true, y equals x: true x = Map(A -> 1, C -> 3, B -> 2) -y = Map(A -> 1, C -> 3, B -> 2) +y = Map(B -> 2, C -> 3, A -> 1) x equals y: true, y equals x: true x = Set(layers, buffers, title) diff --git a/test/files/jvm/t1600.scala b/test/files/jvm/t1600.scala new file mode 100644 index 0000000000..1cdcee8547 --- /dev/null +++ b/test/files/jvm/t1600.scala @@ -0,0 +1,76 @@ + +/** + * Checks that serialization of hash-based collections works correctly if the hashCode + * changes on deserialization. + */ +object Test { + + import collection._ + def main(args: Array[String]) { + for (i <- Seq(0, 1, 2, 10, 100)) { + def entries = (0 until i).map(i => (new Foo, i)).toList + def elements = entries.map(_._1) + + val maps = Seq[Map[Foo, Int]](new mutable.HashMap, new mutable.LinkedHashMap, + new immutable.HashMap).map(_ ++ entries) + test[Map[Foo, Int]](maps, entries.size, assertMap _) + + val sets = Seq[Set[Foo]](new mutable.HashSet, new mutable.LinkedHashSet, + new immutable.HashSet).map(_ ++ elements) + test[Set[Foo]](sets, entries.size, assertSet _) + } + } + + private def test[A <: AnyRef](collections: Seq[A], expectedSize: Int, assertFunction: (A, Int) => Unit) { + for (collection <- collections) { + assertFunction(collection, expectedSize) + + val bytes = toBytes(collection) + Foo.hashCodeModifier = 1 + val deserializedCollection = toObject[A](bytes) + + assertFunction(deserializedCollection, expectedSize) + assert(deserializedCollection.getClass == collection.getClass, + "collection class should remain the same after deserialization") + Foo.hashCodeModifier = 0 + } + } + + private def toObject[A](bytes: Array[Byte]): A = { + val in = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(bytes)) + in.readObject.asInstanceOf[A] + } + + private def toBytes(o: AnyRef): Array[Byte] = { + val bos = new java.io.ByteArrayOutputStream + val out = new java.io.ObjectOutputStream(bos) + out.writeObject(o) + out.close + bos.toByteArray + } + + private def assertMap[A, B](map: Map[A, B], expectedSize: Int) { + assert(expectedSize == map.size, "expected map size: " + expectedSize + ", actual size: " + map.size) + map.foreach { case (k, v) => + assert(map.contains(k), "contains should return true for key in the map, key: " + k) + assert(map(k) == v) + } + } + + private def assertSet[A](set: Set[A], expectedSize: Int) { + assert(expectedSize == set.size, "expected set size: " + expectedSize + ", actual size: " + set.size) + set.foreach { e => assert(set.contains(e), "contains should return true for element in the set, element: " + e) } + } + + object Foo { + /* Used to simulate a hashCode change caused by deserializing an instance with an + * identity-based hashCode in another JVM. + */ + var hashCodeModifier = 0 + } + + @serializable + class Foo { + override def hashCode = System.identityHashCode(this) + Foo.hashCodeModifier + } +} -- cgit v1.2.3