summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/library/scala/collection/immutable/HashMap.scala24
-rw-r--r--src/library/scala/collection/immutable/HashSet.scala22
-rw-r--r--src/library/scala/collection/mutable/FlatHashTable.scala55
-rw-r--r--src/library/scala/collection/mutable/HashMap.scala10
-rw-r--r--src/library/scala/collection/mutable/HashSet.scala10
-rw-r--r--src/library/scala/collection/mutable/HashTable.scala86
-rw-r--r--src/library/scala/collection/mutable/LinkedHashMap.scala28
-rw-r--r--src/library/scala/collection/mutable/LinkedHashSet.scala13
-rw-r--r--test/files/jvm/serialization.check6
-rw-r--r--test/files/jvm/t1600.scala76
10 files changed, 280 insertions, 50 deletions
diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala
index a036a9abfb..7825d62527 100644
--- a/src/library/scala/collection/immutable/HashMap.scala
+++ b/src/library/scala/collection/immutable/HashMap.scala
@@ -33,15 +33,15 @@ import annotation.unchecked.uncheckedVariance
* @version 2.0, 19/01/2007
* @since 2.3
*/
-@serializable @SerialVersionUID(8886909077084990906L)
+@serializable @SerialVersionUID(1L)
class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mutable.HashTable[A] {
type Entry = scala.collection.mutable.DefaultEntry[A, Any]
- protected var later: HashMap[A, B @uncheckedVariance] = null
- protected var oldKey: A = _
- protected var oldValue: Option[B @uncheckedVariance] = _
- protected var deltaSize: Int = _
+ @transient protected var later: HashMap[A, B @uncheckedVariance] = null
+ @transient protected var oldKey: A = _
+ @transient protected var oldValue: Option[B @uncheckedVariance] = _
+ @transient protected var deltaSize: Int = _
override def empty = HashMap.empty[A, B]
@@ -125,10 +125,12 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut
private def logLimit: Int = math.sqrt(table.length).toInt
private[this] def markUpdated(key: A, ov: Option[B], delta: Int) {
- val lv = loadFactor
+ val lf = loadFactor
later = new HashMap[A, B] {
override def initialSize = 0
- override def loadFactor = lv
+ /* We need to do this to avoid a reference to the outer HashMap */
+ _loadFactor = lf
+ override def loadFactor = _loadFactor
table = HashMap.this.table
tableSize = HashMap.this.tableSize
threshold = HashMap.this.threshold
@@ -174,6 +176,14 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with mut
while (m.later != null) m = m.later
if (m ne this) makeCopy(m)
}
+
+ private def writeObject(out: java.io.ObjectOutputStream) {
+ serializeTo(out, _.value)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init[B](in, new Entry(_, _))
+ }
}
/** A factory object for immutable HashMaps.
diff --git a/src/library/scala/collection/immutable/HashSet.scala b/src/library/scala/collection/immutable/HashSet.scala
index e55469d173..0ad6e156a2 100644
--- a/src/library/scala/collection/immutable/HashSet.scala
+++ b/src/library/scala/collection/immutable/HashSet.scala
@@ -32,16 +32,16 @@ import generic._
* @version 2.8
* @since 2.3
*/
-@serializable @SerialVersionUID(4020728942921483037L)
+@serializable @SerialVersionUID(1L)
class HashSet[A] extends Set[A]
with GenericSetTemplate[A, HashSet]
with SetLike[A, HashSet[A]]
with mutable.FlatHashTable[A] {
override def companion: GenericCompanion[HashSet] = HashSet
- protected var later: HashSet[A] = null
- protected var changedElem: A = _
- protected var deleted: Boolean = _
+ @transient protected var later: HashSet[A] = null
+ @transient protected var changedElem: A = _
+ @transient protected var deleted: Boolean = _
def contains(elem: A): Boolean = synchronized {
var m = this
@@ -99,10 +99,12 @@ class HashSet[A] extends Set[A]
private def logLimit: Int = math.sqrt(table.length).toInt
private def markUpdated(elem: A, del: Boolean) {
- val lv = loadFactor
+ val lf = loadFactor
later = new HashSet[A] {
override def initialSize = 0
- override def loadFactor = lv
+ /* We need to do this to avoid a reference to the outer HashMap */
+ _loadFactor = lf
+ override def loadFactor = _loadFactor
table = HashSet.this.table
tableSize = HashSet.this.tableSize
threshold = HashSet.this.threshold
@@ -132,6 +134,14 @@ class HashSet[A] extends Set[A]
while (m.later != null) m = m.later
if (m ne this) makeCopy(m)
}
+
+ private def writeObject(s: java.io.ObjectOutputStream) {
+ serializeTo(s)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init(in, x => x)
+ }
}
/** A factory object for immutable HashSets.
diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala
index 1d55933050..d06ead7888 100644
--- a/src/library/scala/collection/mutable/FlatHashTable.scala
+++ b/src/library/scala/collection/mutable/FlatHashTable.scala
@@ -27,18 +27,63 @@ trait FlatHashTable[A] {
private final val tableDebug = false
+ @transient private[collection] var _loadFactor = loadFactor
+
/** The actual hash table.
*/
- protected var table: Array[AnyRef] =
- if (initialSize == 0) null else new Array(initialSize)
+ @transient protected var table: Array[AnyRef] = new Array(initialCapacity)
/** The number of mappings contained in this hash table.
*/
- protected var tableSize = 0
+ @transient protected var tableSize = 0
/** The next size value at which to resize (capacity * load factor).
*/
- protected var threshold: Int = newThreshold(initialSize)
+ @transient protected var threshold: Int = newThreshold(initialCapacity)
+
+ import HashTable.powerOfTwo
+ private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
+ private def initialCapacity = capacity(initialSize)
+
+ /**
+ * Initialises the collection from the input stream. `f` will be called for each element
+ * read from the input stream in the order determined by the stream. This is useful for
+ * structures where iteration order is important (e.g. LinkedHashSet).
+ *
+ * The serialization format expected is the one produced by `serializeTo`.
+ */
+ private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) {
+ in.defaultReadObject
+
+ _loadFactor = in.readInt
+ assert(_loadFactor > 0)
+
+ val size = in.readInt
+ assert(size >= 0)
+
+ table = new Array(capacity(size * loadFactorDenum / _loadFactor))
+ threshold = newThreshold(table.size)
+
+ var index = 0
+ while (index < size) {
+ val elem = in.readObject.asInstanceOf[A]
+ f(elem)
+ addEntry(elem)
+ index += 1
+ }
+ }
+
+ /**
+ * Serializes the collection to the output stream by saving the load factor, collection
+ * size and collection elements. `foreach` determines the order in which the elements are saved
+ * to the stream. To deserialize, `init` should be used.
+ */
+ private[collection] def serializeTo(out: java.io.ObjectOutputStream) {
+ out.defaultWriteObject
+ out.writeInt(_loadFactor)
+ out.writeInt(tableSize)
+ iterator.foreach(out.writeObject)
+ }
def findEntry(elem: A): Option[A] = {
var h = index(elemHashCode(elem))
@@ -154,7 +199,7 @@ trait FlatHashTable[A] {
protected final def index(hcode: Int) = improve(hcode) & (table.length - 1)
private def newThreshold(size: Int) = {
- val lf = loadFactor
+ val lf = _loadFactor
assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5")
(size.toLong * lf / loadFactorDenum ).toInt
}
diff --git a/src/library/scala/collection/mutable/HashMap.scala b/src/library/scala/collection/mutable/HashMap.scala
index e4969a3af0..21d524129b 100644
--- a/src/library/scala/collection/mutable/HashMap.scala
+++ b/src/library/scala/collection/mutable/HashMap.scala
@@ -17,7 +17,7 @@ import generic._
/**
* @since 1
*/
-@serializable @SerialVersionUID(-8682987922734091219L)
+@serializable @SerialVersionUID(1L)
class HashMap[A, B] extends Map[A, B]
with MapLike[A, B, HashMap[A, B]]
with HashTable[A] {
@@ -84,6 +84,14 @@ class HashMap[A, B] extends Map[A, B]
def hasNext = iter.hasNext
def next = iter.next.value
}
+
+ private def writeObject(out: java.io.ObjectOutputStream) {
+ serializeTo(out, _.value)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init[B](in, new Entry(_, _))
+ }
}
/** This class implements mutable maps using a hashtable.
diff --git a/src/library/scala/collection/mutable/HashSet.scala b/src/library/scala/collection/mutable/HashSet.scala
index 9144a4be88..0630f96fa2 100644
--- a/src/library/scala/collection/mutable/HashSet.scala
+++ b/src/library/scala/collection/mutable/HashSet.scala
@@ -21,7 +21,7 @@ import generic._
* @version 2.0, 31/12/2006
* @since 1
*/
-@serializable
+@serializable @SerialVersionUID(1L)
class HashSet[A] extends Set[A]
with GenericSetTemplate[A, HashSet]
with SetLike[A, HashSet[A]]
@@ -51,6 +51,14 @@ class HashSet[A] extends Set[A]
}
override def clone(): Set[A] = new HashSet[A] ++= this
+
+ private def writeObject(s: java.io.ObjectOutputStream) {
+ serializeTo(s)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init(in, x => x)
+ }
}
/** Factory object for `HashSet` class */
diff --git a/src/library/scala/collection/mutable/HashTable.scala b/src/library/scala/collection/mutable/HashTable.scala
index fa3865ada2..cdb6d6904b 100644
--- a/src/library/scala/collection/mutable/HashTable.scala
+++ b/src/library/scala/collection/mutable/HashTable.scala
@@ -31,6 +31,7 @@ package mutable
* @since 1
*/
trait HashTable[A] {
+ import HashTable._
protected type Entry >: Null <: HashEntry[A, Entry]
@@ -47,33 +48,65 @@ trait HashTable[A] {
*/
protected def initialThreshold: Int = newThreshold(initialCapacity)
+ @transient private[collection] var _loadFactor = loadFactor
+
/** The actual hash table.
*/
- protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity)
+ @transient protected var table: Array[HashEntry[A, Entry]] = new Array(initialCapacity)
+
+ /** The number of mappings contained in this hash table.
+ */
+ @transient protected var tableSize: Int = 0
- private def initialCapacity = if (initialSize == 0) 1 else powerOfTwo(initialSize)
+ /** The next size value at which to resize (capacity * load factor).
+ */
+ @transient protected var threshold: Int = initialThreshold
+
+ private def initialCapacity = capacity(initialSize)
/**
- * Returns a power of two >= `target`.
+ * Initialises the collection from the input stream. `f` will be called for each key/value pair
+ * read from the input stream in the order determined by the stream. This is useful for
+ * structures where iteration order is important (e.g. LinkedHashMap).
*/
- private def powerOfTwo(target: Int): Int = {
- /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */
- var c = target - 1;
- c |= c >>> 1;
- c |= c >>> 2;
- c |= c >>> 4;
- c |= c >>> 8;
- c |= c >>> 16;
- c + 1;
+ private[collection] def init[B](in: java.io.ObjectInputStream, f: (A, B) => Entry) {
+ in.defaultReadObject
+
+ _loadFactor = in.readInt
+ assert(_loadFactor > 0)
+
+ val size = in.readInt
+ assert(size >= 0)
+
+ table = new Array(capacity(size * loadFactorDenum / _loadFactor))
+ threshold = newThreshold(table.size)
+
+ var index = 0
+ while (index < size) {
+ addEntry(f(in.readObject.asInstanceOf[A], in.readObject.asInstanceOf[B]))
+ index += 1
+ }
}
- /** The number of mappings contained in this hash table.
+ /**
+ * Serializes the collection to the output stream by saving the load factor, collection
+ * size, collection keys and collection values. `value` is responsible for providing a value
+ * from an entry.
+ *
+ * `foreach` determines the order in which the key/value pairs are saved to the stream. To
+ * deserialize, `init` should be used.
*/
- protected var tableSize: Int = 0
+ private[collection] def serializeTo[B](out: java.io.ObjectOutputStream, value: Entry => B) {
+ out.defaultWriteObject
+ out.writeInt(loadFactor)
+ out.writeInt(tableSize)
+ foreachEntry { entry =>
+ out.writeObject(entry.key)
+ out.writeObject(value(entry))
+ }
+ }
- /** The next size value at which to resize (capacity * load factor).
- */
- protected var threshold: Int = initialThreshold
+ private def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
/** Find entry with given key in table, null if not found
*/
@@ -172,9 +205,9 @@ trait HashTable[A] {
}
private def newThreshold(size: Int) =
- ((size.toLong * loadFactor)/loadFactorDenum).toInt
+ ((size.toLong * _loadFactor)/loadFactorDenum).toInt
- private def resize(newSize: Int) = {
+ private def resize(newSize: Int) {
val oldTable = table
table = new Array(newSize)
var i = oldTable.length - 1
@@ -206,4 +239,19 @@ trait HashTable[A] {
protected final def index(hcode: Int) = improve(hcode) & (table.length - 1)
}
+private[collection] object HashTable {
+ /**
+ * Returns a power of two >= `target`.
+ */
+ private[collection] def powerOfTwo(target: Int): Int = {
+ /* See http://bits.stephan-brumme.com/roundUpToNextPowerOfTwo.html */
+ var c = target - 1;
+ c |= c >>> 1;
+ c |= c >>> 2;
+ c |= c >>> 4;
+ c |= c >>> 8;
+ c |= c >>> 16;
+ c + 1;
+ }
+}
diff --git a/src/library/scala/collection/mutable/LinkedHashMap.scala b/src/library/scala/collection/mutable/LinkedHashMap.scala
index 308db1a4d4..183a3aedb0 100644
--- a/src/library/scala/collection/mutable/LinkedHashMap.scala
+++ b/src/library/scala/collection/mutable/LinkedHashMap.scala
@@ -29,7 +29,7 @@ object LinkedHashMap extends MutableMapFactory[LinkedHashMap] {
/**
* @since 2.7
*/
-@serializable
+@serializable @SerialVersionUID(1L)
class LinkedHashMap[A, B] extends Map[A, B]
with MapLike[A, B, LinkedHashMap[A, B]]
with HashTable[A] {
@@ -39,8 +39,8 @@ class LinkedHashMap[A, B] extends Map[A, B]
type Entry = LinkedEntry[A, B]
- protected var firstEntry: Entry = null
- protected var lastEntry: Entry = null
+ @transient protected var firstEntry: Entry = null
+ @transient protected var lastEntry: Entry = null
def get(key: A): Option[B] = {
val e = findEntry(key)
@@ -53,9 +53,7 @@ class LinkedHashMap[A, B] extends Map[A, B]
if (e == null) {
val e = new Entry(key, value)
addEntry(e)
- if (firstEntry == null) firstEntry = e
- else { lastEntry.later = e; e.earlier = lastEntry }
- lastEntry = e
+ updateLinkedEntries(e)
None
} else {
val v = e.value
@@ -64,6 +62,12 @@ class LinkedHashMap[A, B] extends Map[A, B]
}
}
+ private def updateLinkedEntries(e: Entry) {
+ if (firstEntry == null) firstEntry = e
+ else { lastEntry.later = e; e.earlier = lastEntry }
+ lastEntry = e
+ }
+
override def remove(key: A): Option[B] = {
val e = removeEntry(key)
if (e eq null) None
@@ -115,4 +119,16 @@ class LinkedHashMap[A, B] extends Map[A, B]
clearTable()
firstEntry = null
}
+
+ private def writeObject(out: java.io.ObjectOutputStream) {
+ serializeTo(out, _.value)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ init[B](in, { (key, value) =>
+ val entry = new Entry(key, value)
+ updateLinkedEntries(entry)
+ entry
+ })
+ }
}
diff --git a/src/library/scala/collection/mutable/LinkedHashSet.scala b/src/library/scala/collection/mutable/LinkedHashSet.scala
index 081b068723..b9313bd1d4 100644
--- a/src/library/scala/collection/mutable/LinkedHashSet.scala
+++ b/src/library/scala/collection/mutable/LinkedHashSet.scala
@@ -17,7 +17,7 @@ import generic._
* Should be rewritten to be more efficient.
* @since 2.2
*/
-@serializable
+@serializable @SerialVersionUID(1L)
class LinkedHashSet[A] extends Set[A]
with GenericSetTemplate[A, LinkedHashSet]
with SetLike[A, LinkedHashSet[A]]
@@ -25,7 +25,7 @@ class LinkedHashSet[A] extends Set[A]
{
override def companion: GenericCompanion[LinkedHashSet] = LinkedHashSet
- protected val ordered = new ListBuffer[A]
+ @transient private var ordered = new ListBuffer[A]
override def size = tableSize
@@ -52,6 +52,15 @@ class LinkedHashSet[A] extends Set[A]
override def iterator = ordered.iterator
override def foreach[U](f: A => U) = ordered foreach f
+
+ private def writeObject(s: java.io.ObjectOutputStream) {
+ serializeTo(s)
+ }
+
+ private def readObject(in: java.io.ObjectInputStream) {
+ ordered = new ListBuffer[A]
+ init(in, ordered += )
+ }
}
/** Factory object for `LinkedHashSet` class */
diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check
index f1b5b10ec6..2b0ad3888b 100644
--- a/test/files/jvm/serialization.check
+++ b/test/files/jvm/serialization.check
@@ -78,11 +78,11 @@ y = BitSet(2, 3)
x equals y: true, y equals x: true
x = Map(2 -> B, 1 -> A, 3 -> C)
-y = Map(2 -> B, 1 -> A, 3 -> C)
+y = Map(1 -> A, 2 -> B, 3 -> C)
x equals y: true, y equals x: true
x = Set(1, 2)
-y = Set(1, 2)
+y = Set(2, 1)
x equals y: true, y equals x: true
x = List((buffers,20), (layers,2), (title,3))
@@ -158,7 +158,7 @@ y = BitSet(0, 8, 9)
x equals y: true, y equals x: true
x = Map(A -> 1, C -> 3, B -> 2)
-y = Map(A -> 1, C -> 3, B -> 2)
+y = Map(B -> 2, C -> 3, A -> 1)
x equals y: true, y equals x: true
x = Set(layers, buffers, title)
diff --git a/test/files/jvm/t1600.scala b/test/files/jvm/t1600.scala
new file mode 100644
index 0000000000..1cdcee8547
--- /dev/null
+++ b/test/files/jvm/t1600.scala
@@ -0,0 +1,76 @@
+
+/**
+ * Checks that serialization of hash-based collections works correctly if the hashCode
+ * changes on deserialization.
+ */
+object Test {
+
+ import collection._
+ def main(args: Array[String]) {
+ for (i <- Seq(0, 1, 2, 10, 100)) {
+ def entries = (0 until i).map(i => (new Foo, i)).toList
+ def elements = entries.map(_._1)
+
+ val maps = Seq[Map[Foo, Int]](new mutable.HashMap, new mutable.LinkedHashMap,
+ new immutable.HashMap).map(_ ++ entries)
+ test[Map[Foo, Int]](maps, entries.size, assertMap _)
+
+ val sets = Seq[Set[Foo]](new mutable.HashSet, new mutable.LinkedHashSet,
+ new immutable.HashSet).map(_ ++ elements)
+ test[Set[Foo]](sets, entries.size, assertSet _)
+ }
+ }
+
+ private def test[A <: AnyRef](collections: Seq[A], expectedSize: Int, assertFunction: (A, Int) => Unit) {
+ for (collection <- collections) {
+ assertFunction(collection, expectedSize)
+
+ val bytes = toBytes(collection)
+ Foo.hashCodeModifier = 1
+ val deserializedCollection = toObject[A](bytes)
+
+ assertFunction(deserializedCollection, expectedSize)
+ assert(deserializedCollection.getClass == collection.getClass,
+ "collection class should remain the same after deserialization")
+ Foo.hashCodeModifier = 0
+ }
+ }
+
+ private def toObject[A](bytes: Array[Byte]): A = {
+ val in = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(bytes))
+ in.readObject.asInstanceOf[A]
+ }
+
+ private def toBytes(o: AnyRef): Array[Byte] = {
+ val bos = new java.io.ByteArrayOutputStream
+ val out = new java.io.ObjectOutputStream(bos)
+ out.writeObject(o)
+ out.close
+ bos.toByteArray
+ }
+
+ private def assertMap[A, B](map: Map[A, B], expectedSize: Int) {
+ assert(expectedSize == map.size, "expected map size: " + expectedSize + ", actual size: " + map.size)
+ map.foreach { case (k, v) =>
+ assert(map.contains(k), "contains should return true for key in the map, key: " + k)
+ assert(map(k) == v)
+ }
+ }
+
+ private def assertSet[A](set: Set[A], expectedSize: Int) {
+ assert(expectedSize == set.size, "expected set size: " + expectedSize + ", actual size: " + set.size)
+ set.foreach { e => assert(set.contains(e), "contains should return true for element in the set, element: " + e) }
+ }
+
+ object Foo {
+ /* Used to simulate a hashCode change caused by deserializing an instance with an
+ * identity-based hashCode in another JVM.
+ */
+ var hashCodeModifier = 0
+ }
+
+ @serializable
+ class Foo {
+ override def hashCode = System.identityHashCode(this) + Foo.hashCodeModifier
+ }
+}