summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2009-12-01 18:28:55 +0000
committerPaul Phillips <paulp@improving.org>2009-12-01 18:28:55 +0000
commita3bf3f136caaefa98268607a3529b7554df5fc80 (patch)
tree220322df1130f864af06661baca1b7f4434ee32d
parentc2359ccec521ed24641fb010774e7b39b4ae62b2 (diff)
downloadscala-a3bf3f136caaefa98268607a3529b7554df5fc80.tar.gz
scala-a3bf3f136caaefa98268607a3529b7554df5fc80.tar.bz2
scala-a3bf3f136caaefa98268607a3529b7554df5fc80.zip
[This patch submitted by ismael juma - commit m...
[This patch submitted by ismael juma - commit message his words, but condensed.] Fix ticket #1600: Serialization and deserialization of hash-based collections should not re-use hashCode. The collection is rebuilt on deserialization - note that this is not compatible with the previous serialization format. All @SerialVersionUIDs have been reset to 1. WeakHashMap is not Serializable and should not be so. TreeHashMap has not been reintegrated yet. OpenHashMap has not been updated. (I think this collection is flawed and should be removed or reimplemented.)
-rw-r--r--src/library/scala/collection/immutable/HashMap.scala24
-rw-r--r--src/library/scala/collection/immutable/HashSet.scala22
-rw-r--r--src/library/scala/collection/mutable/FlatHashTable.scala55
-rw-r--r--src/library/scala/collection/mutable/HashMap.scala10
-rw-r--r--src/library/scala/collection/mutable/HashSet.scala10
-rw-r--r--src/library/scala/collection/mutable/HashTable.scala86
-rw-r--r--src/library/scala/collection/mutable/LinkedHashMap.scala28
-rw-r--r--src/library/scala/collection/mutable/LinkedHashSet.scala13
-rw-r--r--test/files/jvm/serialization.check6
-rw-r--r--test/files/jvm/t1600.scala76
10 files changed, 280 insertions, 50 deletions
diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala
index a036a9abfb..7825d62527 100644
--- a/src/library/scala/collection/immutable/HashMap.scala
+++ b/src/library/scala/collection/immutable/HashMap.scala
@@ -33,15 +33,15 @@ import annotation.unchecked.uncheckedVariance
* @version 2.0, 19/01/2007
* @since 2.3
*/
-@serializable @SerialVersionUID(8886909077084990906L)
+@serializable @SerialVersionUID(1L)
class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mutable.HashTable[A] {
type Entry = scala.collection.mutable.DefaultEntry[A, Any]
- protected var later: HashMap[A, B @uncheckedVariance] = null
- protected var oldKey: A = _
- protected var oldValue: Option[B @uncheckedVariance] = _
- protected var deltaSize: Int = _
+ @transient protected var later: HashMap[A, B @uncheckedVariance] = null
+ @transient protected var oldKey: A = _
+ @transient protected var oldValue: Option[B @uncheckedVariance] = _
+ @transient protected var deltaSize: Int = _
override def empty = HashMap.empty[A, B]
@@ -125,10 +125,12 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut
private def logLimit: Int = math.sqrt(table.length).toInt
private[this] def markUpdated(key: A, ov: Option[B], delta: Int) {
- val lv = loadFactor
+ val lf = loadFactor
later = new HashMap[A, B] {
override def initialSize = 0
- override def loadFactor = lv
+ /* We need to do this to avoid a reference to the outer HashMap */
+ _loadFactor = lf
+ override def loadFactor = _loadFactor
table = HashMap.this.table
tableSize = HashMap.this.tableSize
threshold = HashMap.this.threshold
@@ -174,6 +176,14 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut
while (m.later != null) m = m.later
if (m ne this) makeCopy(m)
}
+
+ private def writeObject(out: java.io.ObjectOutputStream) {
+ serializeTo(out, _.value)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init[B](in, new Entry(_, _))
+ }
}
/** A factory object for immutable HashMaps.
diff --git a/src/library/scala/collection/immutable/HashSet.scala b/src/library/scala/collection/immutable/HashSet.scala
index e55469d173..0ad6e156a2 100644
--- a/src/library/scala/collection/immutable/HashSet.scala
+++ b/src/library/scala/collection/immutable/HashSet.scala
@@ -32,16 +32,16 @@ import generic._
* @version 2.8
* @since 2.3
*/
-@serializable @SerialVersionUID(4020728942921483037L)
+@serializable @SerialVersionUID(1L)
class HashSet[A] extends Set[A]
with GenericSetTemplate[A, HashSet]
with SetLike[A, HashSet[A]]
with mutable.FlatHashTable[A] {
override def companion: GenericCompanion[HashSet] = HashSet
- protected var later: HashSet[A] = null
- protected var changedElem: A = _
- protected var deleted: Boolean = _
+ @transient protected var later: HashSet[A] = null
+ @transient protected var changedElem: A = _
+ @transient protected var deleted: Boolean = _
def contains(elem: A): Boolean = synchronized {
var m = this
@@ -99,10 +99,12 @@ class HashSet[A] extends Set[A]
private def logLimit: Int = math.sqrt(table.length).toInt
private def markUpdated(elem: A, del: Boolean) {
- val lv = loadFactor
+ val lf = loadFactor
later = new HashSet[A] {
override def initialSize = 0
- override def loadFactor = lv
+ /* We need to do this to avoid a reference to the outer HashSet */
+ _loadFactor = lf
+ override def loadFactor = _loadFactor
table = HashSet.this.table
tableSize = HashSet.this.tableSize
threshold = HashSet.this.threshold
@@ -132,6 +134,14 @@ class HashSet[A] extends Set[A]
while (m.later != null) m = m.later
if (m ne this) makeCopy(m)
}
+
+ private def writeObject(s: java.io.ObjectOutputStream) {
+ serializeTo(s)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init(in, x => x)
+ }
}
/** A factory object for immutable HashSets.
diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala
index 1d55933050..d06ead7888 100644
--- a/src/library/scala/collection/mutable/FlatHashTable.scala
+++ b/src/library/scala/collection/mutable/FlatHashTable.scala
@@ -27,18 +27,63 @@ trait FlatHashTable[A] {
private final val tableDebug = false
+ @transient private[collection] var _loadFactor = loadFactor
+
/** The actual hash table.
*/
- protected var table: Array[AnyRef] =
- if (initialSize == 0) null else new Array(initialSize)
+ @transient protected var table: Array[AnyRef] = new Array(initialCapacity)
/** The number of mappings contained in this hash table.
*/
- protected var tableSize = 0
+ @transient protected var tableSize = 0
/** The next size value at which to resize (capacity * load factor).
*/
- protected var threshold: Int = newThreshold(initialSize)
+ @transient protected var threshold: Int = newThreshold(initialCapacity)
+
+ import HashTable.powerOfTwo
+ private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
+ private def initialCapacity = capacity(initialSize)
+
+ /**
+ * Initialises the collection from the input stream. `f` will be called for each element
+ * read from the input stream in the order determined by the stream. This is useful for
+ * structures where iteration order is important (e.g. LinkedHashSet).
+ *
+ * The serialization format expected is the one produced by `serializeTo`: the default
+ * serialized fields, the load factor (Int), the element count (Int) and then the elements.
+ *
+ * @param in the stream to read the collection from
+ * @param f  callback invoked on each element just before it is added to the table
+ */
+ private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) {
+   in.defaultReadObject
+
+   _loadFactor = in.readInt
+   assert(_loadFactor > 0)
+
+   val size = in.readInt
+   assert(size >= 0)
+
+   // Pick a capacity whose threshold can hold `size` elements. The product is computed as
+   // a Long because `size * loadFactorDenum` can overflow Int when `size` is large, which
+   // would produce a wrong (possibly negative) capacity.
+   table = new Array(capacity((size.toLong * loadFactorDenum / _loadFactor).toInt))
+   threshold = newThreshold(table.length)
+
+   var index = 0
+   while (index < size) {
+     val elem = in.readObject.asInstanceOf[A]
+     // Let the caller observe the element (e.g. to record insertion order) before it is stored.
+     f(elem)
+     addEntry(elem)
+     index += 1
+   }
+ }
+
+ /**
+ * Serializes the collection to the output stream by saving the load factor, collection
+ * size and collection elements. `foreach` determines the order in which the elements are saved
+ * to the stream. To deserialize, `init` should be used.
+ */
+ private[collection] def serializeTo(out: java.io.ObjectOutputStream) {
+ out.defaultWriteObject
+ out.writeInt(_loadFactor)
+ out.writeInt(tableSize)
+ iterator.foreach(out.writeObject)
+ }
def findEntry(elem: A): Option[A] = {
var h = index(elemHashCode(elem))
@@ -154,7 +199,7 @@ trait FlatHashTable[A] {
protected final def index(hcode: Int) = improve(hcode) & (table.length - 1)
private def newThreshold(size: Int) = {
- val lf = loadFactor
+ val lf = _loadFactor
assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5")
(size.toLong * lf / loadFactorDenum ).toInt
}
diff --git a/src/library/scala/collection/mutable/HashMap.scala b/src/library/scala/collection/mutable/HashMap.scala
index e4969a3af0..21d524129b 100644
--- a/src/library/scala/collection/mutable/HashMap.scala
+++ b/src/library/scala/collection/mutable/HashMap.scala
@@ -17,7 +17,7 @@ import generic._
/**
* @since 1
*/
-@serializable @SerialVersionUID(-8682987922734091219L)
+@serializable @SerialVersionUID(1L)
class HashMap[A, B] extends Map[A, B]
with MapLike[A, B, HashMap[A, B]]
with HashTable[A] {
@@ -84,6 +84,14 @@ class HashMap[A, B] extends Map[A, B]
def hasNext = iter.hasNext
def next = iter.next.value
}
+
+ private def writeObject(out: java.io.ObjectOutputStream) {
+ serializeTo(out, _.value)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init[B](in, new Entry(_, _))
+ }
}
/** This class implements mutable maps using a hashtable.
diff --git a/src/library/scala/collection/mutable/HashSet.scala b/src/library/scala/collection/mutable/HashSet.scala
index 9144a4be88..0630f96fa2 100644
--- a/src/library/scala/collection/mutable/HashSet.scala
+++ b/src/library/scala/collection/mutable/HashSet.scala
@@ -21,7 +21,7 @@ import generic._
* @version 2.0, 31/12/2006
* @since 1
*/
-@serializable
+@serializable @SerialVersionUID(1L)
class HashSet[A] extends Set[A]
with GenericSetTemplate[A, HashSet]
with SetLike[A, HashSet[A]]
@@ -51,6 +51,14 @@ class HashSet[A] extends Set[A]
}
override def clone(): Set[A] = new HashSet[A] ++= this
+
+ private def writeObject(s: java.io.ObjectOutputStream) {
+ serializeTo(s)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init(in, x => x)
+ }
}
/** Factory object for `HashSet` class */
diff --git a/src/library/scala/collection/mutable/HashTable.scala b/src/library/scala/collection/mutable/HashTable.scala
index fa3865ada2..cdb6d6904b 100644
--- a/src/library/scala/collection/mutable/HashTable.scala
+++ b/src/library/scala/collection/mutable/HashTable.scala
@@ -31,6 +31,7 @@ package mutable
* @since 1
*/
trait HashTable[A] {
+ import HashTable._
protected type Entry >: Null <: HashEntry[A, Entry]
@@ -47,33 +48,65 @@ trait HashTable[A] {
*/
protected def initialThreshold: Int = newThreshold(initialCapacity)
+ @transient private[collection] var _loadFactor = loadFactor
+
/** The actual hash table.
*/
- protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity)
+ @transient protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity)
+
+ /** The number of mappings contained in this hash table.
+ */
+ @transient protected var tableSize: Int = 0
- private def initialCapacity = if (initialSize == 0) 1 else powerOfTwo(initialSize)
+ /** The next size value at which to resize (capacity * load factor).
+ */
+ @transient protected var threshold: Int = initialThreshold
+
+ private def initialCapacity = capacity(initialSize)
/**
- * Returns a power of two >= `target`.
+ * Initialises the collection from the input stream. `f` will be called for each key/value pair
+ * read from the input stream in the order determined by the stream. This is useful for
+ * structures where iteration order is important (e.g. LinkedHashMap).
*/
- private def powerOfTwo(target: Int): Int = {
- /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */
- var c = target - 1;
- c |= c >>> 1;
- c |= c >>> 2;
- c |= c >>> 4;
- c |= c >>> 8;
- c |= c >>> 16;
- c + 1;
+ private[collection] def init[B](in: java.io.ObjectInputStream, f: (A, B) => Entry) {
+   in.defaultReadObject
+
+   // Stream layout, as written by `serializeTo`: load factor, entry count, key/value pairs.
+   _loadFactor = in.readInt
+   assert(_loadFactor > 0)
+
+   val size = in.readInt
+   assert(size >= 0)
+
+   // Pick a capacity whose threshold can hold `size` entries. The product is computed as
+   // a Long because `size * loadFactorDenum` can overflow Int when `size` is large, which
+   // would produce a wrong (possibly negative) capacity.
+   table = new Array(capacity((size.toLong * loadFactorDenum / _loadFactor).toInt))
+   threshold = newThreshold(table.length)
+
+   var index = 0
+   while (index < size) {
+     // The key is read before the value, matching the order used by `serializeTo`.
+     addEntry(f(in.readObject.asInstanceOf[A], in.readObject.asInstanceOf[B]))
+     index += 1
+   }
}
- /** The number of mappings contained in this hash table.
+ /**
+ * Serializes the collection to the output stream by saving the load factor, collection
+ * size, collection keys and collection values. `value` is responsible for providing a value
+ * from an entry.
+ *
+ * `foreach` determines the order in which the key/value pairs are saved to the stream. To
+ * deserialize, `init` should be used.
*/
- protected var tableSize: Int = 0
+ private[collection] def serializeTo[B](out: java.io.ObjectOutputStream, value: Entry => B) {
+   out.defaultWriteObject
+   // Write `_loadFactor`, the value that `init` restores and `newThreshold` reads, rather
+   // than the `loadFactor` default, so that a table which was itself deserialized with a
+   // non-default load factor is written back with the load factor it actually uses.
+   out.writeInt(_loadFactor)
+   out.writeInt(tableSize)
+   // Each entry is written as its key followed by its value; `init` reads them in that order.
+   foreachEntry { entry =>
+     out.writeObject(entry.key)
+     out.writeObject(value(entry))
+   }
+ }
- /** The next size value at which to resize (capacity * load factor).
- */
- protected var threshold: Int = initialThreshold
+ private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
/** Find entry with given key in table, null if not found
*/
@@ -172,9 +205,9 @@ trait HashTable[A] {
}
private def newThreshold(size: Int) =
- ((size.toLong * loadFactor)/loadFactorDenum).toInt
+ ((size.toLong * _loadFactor)/loadFactorDenum).toInt
- private def resize(newSize: Int) = {
+ private def resize(newSize: Int) {
val oldTable = table
table = new Array(newSize)
var i = oldTable.length - 1
@@ -206,4 +239,19 @@ trait HashTable[A] {
protected final def index(hcode: Int) = improve(hcode) & (table.length - 1)
}
+private[collection] object HashTable {
+ /**
+ * Returns a power of two >= `target`.
+ *
+ * Works by copying the highest set bit of `target - 1` into every lower bit
+ * position and then adding one.
+ *
+ * Boundary behaviour: a `target` of 0 yields 0 (not a power of two), and a
+ * `target` greater than 2^30 yields a negative value because the final
+ * `c + 1` overflows `Int`. Callers that may pass 0 have to handle it themselves.
+ */
+ private[collection] def powerOfTwo(target: Int): Int = {
+ /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */
+ var c = target - 1;
+ c |= c >>> 1;
+ c |= c >>> 2;
+ c |= c >>> 4;
+ c |= c >>> 8;
+ c |= c >>> 16;
+ c + 1;
+ }
+}
diff --git a/src/library/scala/collection/mutable/LinkedHashMap.scala b/src/library/scala/collection/mutable/LinkedHashMap.scala
index 308db1a4d4..183a3aedb0 100644
--- a/src/library/scala/collection/mutable/LinkedHashMap.scala
+++ b/src/library/scala/collection/mutable/LinkedHashMap.scala
@@ -29,7 +29,7 @@ object LinkedHashMap extends MutableMapFactory[LinkedHashMap] {
/**
* @since 2.7
*/
-@serializable
+@serializable @SerialVersionUID(1L)
class LinkedHashMap[A, B] extends Map[A, B]
with MapLike[A, B, LinkedHashMap[A, B]]
with HashTable[A] {
@@ -39,8 +39,8 @@ class LinkedHashMap[A, B] extends Map[A, B]
type Entry = LinkedEntry[A, B]
- protected var firstEntry: Entry = null
- protected var lastEntry: Entry = null
+ @transient protected var firstEntry: Entry = null
+ @transient protected var lastEntry: Entry = null
def get(key: A): Option[B] = {
val e = findEntry(key)
@@ -53,9 +53,7 @@ class LinkedHashMap[A, B] extends Map[A, B]
if (e == null) {
val e = new Entry(key, value)
addEntry(e)
- if (firstEntry == null) firstEntry = e
- else { lastEntry.later = e; e.earlier = lastEntry }
- lastEntry = e
+ updateLinkedEntries(e)
None
} else {
val v = e.value
@@ -64,6 +62,12 @@ class LinkedHashMap[A, B] extends Map[A, B]
}
}
+ private def updateLinkedEntries(e: Entry) {
+ if (firstEntry == null) firstEntry = e
+ else { lastEntry.later = e; e.earlier = lastEntry }
+ lastEntry = e
+ }
+
override def remove(key: A): Option[B] = {
val e = removeEntry(key)
if (e eq null) None
@@ -115,4 +119,16 @@ class LinkedHashMap[A, B] extends Map[A, B]
clearTable()
firstEntry = null
}
+
+ private def writeObject(out: java.io.ObjectOutputStream) {
+ serializeTo(out, _.value)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init[B](in, { (key, value) =>
+ val entry = new Entry(key, value)
+ updateLinkedEntries(entry)
+ entry
+ })
+ }
}
diff --git a/src/library/scala/collection/mutable/LinkedHashSet.scala b/src/library/scala/collection/mutable/LinkedHashSet.scala
index 081b068723..b9313bd1d4 100644
--- a/src/library/scala/collection/mutable/LinkedHashSet.scala
+++ b/src/library/scala/collection/mutable/LinkedHashSet.scala
@@ -17,7 +17,7 @@ import generic._
* Should be rewritten to be more efficient.
* @since 2.2
*/
-@serializable
+@serializable @SerialVersionUID(1L)
class LinkedHashSet[A] extends Set[A]
with GenericSetTemplate[A, LinkedHashSet]
with SetLike[A, LinkedHashSet[A]]
@@ -25,7 +25,7 @@ class LinkedHashSet[A] extends Set[A]
{
override def companion: GenericCompanion[LinkedHashSet] = LinkedHashSet
- protected val ordered = new ListBuffer[A]
+ @transient private var ordered = new ListBuffer[A]
override def size = tableSize
@@ -52,6 +52,15 @@ class LinkedHashSet[A] extends Set[A]
override def iterator = ordered.iterator
override def foreach[U](f: A => U) = ordered foreach f
+
+ private def writeObject(s: java.io.ObjectOutputStream) {
+ serializeTo(s)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ ordered = new ListBuffer[A]
+ init(in, ordered += )
+ }
}
/** Factory object for `LinkedHashSet` class */
diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check
index f1b5b10ec6..2b0ad3888b 100644
--- a/test/files/jvm/serialization.check
+++ b/test/files/jvm/serialization.check
@@ -78,11 +78,11 @@ y = BitSet(2, 3)
x equals y: true, y equals x: true
x = Map(2 -> B, 1 -> A, 3 -> C)
-y = Map(2 -> B, 1 -> A, 3 -> C)
+y = Map(1 -> A, 2 -> B, 3 -> C)
x equals y: true, y equals x: true
x = Set(1, 2)
-y = Set(1, 2)
+y = Set(2, 1)
x equals y: true, y equals x: true
x = List((buffers,20), (layers,2), (title,3))
@@ -158,7 +158,7 @@ y = BitSet(0, 8, 9)
x equals y: true, y equals x: true
x = Map(A -> 1, C -> 3, B -> 2)
-y = Map(A -> 1, C -> 3, B -> 2)
+y = Map(B -> 2, C -> 3, A -> 1)
x equals y: true, y equals x: true
x = Set(layers, buffers, title)
diff --git a/test/files/jvm/t1600.scala b/test/files/jvm/t1600.scala
new file mode 100644
index 0000000000..1cdcee8547
--- /dev/null
+++ b/test/files/jvm/t1600.scala
@@ -0,0 +1,76 @@
+
+/**
+ * Checks that serialization of hash-based collections works correctly if the hashCode
+ * changes on deserialization.
+ */
+object Test {
+
+ import collection._
+ def main(args: Array[String]) {
+ for (i <- Seq(0, 1, 2, 10, 100)) {
+ def entries = (0 until i).map(i => (new Foo, i)).toList
+ def elements = entries.map(_._1)
+
+ val maps = Seq[Map[Foo, Int]](new mutable.HashMap, new mutable.LinkedHashMap,
+ new immutable.HashMap).map(_ ++ entries)
+ test[Map[Foo, Int]](maps, entries.size, assertMap _)
+
+ val sets = Seq[Set[Foo]](new mutable.HashSet, new mutable.LinkedHashSet,
+ new immutable.HashSet).map(_ ++ elements)
+ test[Set[Foo]](sets, entries.size, assertSet _)
+ }
+ }
+
+ private def test[A <: AnyRef](collections: Seq[A], expectedSize: Int, assertFunction: (A, Int) => Unit) {
+ for (collection <- collections) {
+ assertFunction(collection, expectedSize)
+
+ val bytes = toBytes(collection)
+ Foo.hashCodeModifier = 1
+ val deserializedCollection = toObject[A](bytes)
+
+ assertFunction(deserializedCollection, expectedSize)
+ assert(deserializedCollection.getClass == collection.getClass,
+ "collection class should remain the same after deserialization")
+ Foo.hashCodeModifier = 0
+ }
+ }
+
+ private def toObject[A](bytes: Array[Byte]): A = {
+ val in = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(bytes))
+ in.readObject.asInstanceOf[A]
+ }
+
+ private def toBytes(o: AnyRef): Array[Byte] = {
+ val bos = new java.io.ByteArrayOutputStream
+ val out = new java.io.ObjectOutputStream(bos)
+ out.writeObject(o)
+ out.close
+ bos.toByteArray
+ }
+
+ private def assertMap[A, B](map: Map[A, B], expectedSize: Int) {
+ assert(expectedSize == map.size, "expected map size: " + expectedSize + ", actual size: " + map.size)
+ map.foreach { case (k, v) =>
+ assert(map.contains(k), "contains should return true for key in the map, key: " + k)
+ assert(map(k) == v)
+ }
+ }
+
+ private def assertSet[A](set: Set[A], expectedSize: Int) {
+ assert(expectedSize == set.size, "expected set size: " + expectedSize + ", actual size: " + set.size)
+ set.foreach { e => assert(set.contains(e), "contains should return true for element in the set, element: " + e) }
+ }
+
+ object Foo {
+ /* Used to simulate a hashCode change caused by deserializing an instance with an
+ * identity-based hashCode in another JVM.
+ */
+ var hashCodeModifier = 0
+ }
+
+ @serializable
+ class Foo {
+ override def hashCode = System.identityHashCode(this) + Foo.hashCodeModifier
+ }
+}