From ea91456310736caed3064c117c5a72f515d59688 Mon Sep 17 00:00:00 2001
From: Tiark Rompf
Date: Tue, 13 Apr 2010 11:08:12 +0000
Subject: closes #3241 and improves serialization of hash...

closes #3241 and improves serialization of hash tries. review by community.
---
 .../scala/collection/immutable/HashMap.scala | 98 ++++++++--------------
 .../scala/collection/immutable/HashSet.scala | 87 ++++++++-----------
 test/files/jvm/t1600.scala                   |  6 +-
 test/files/run/t3241.check                   |  1 +
 test/files/run/t3241.scala                   | 23 +++++
 5 files changed, 95 insertions(+), 120 deletions(-)
 create mode 100644 test/files/run/t3241.check
 create mode 100644 test/files/run/t3241.scala

diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala
index c83a5ffcb8..ae438d0d09 100644
--- a/src/library/scala/collection/immutable/HashMap.scala
+++ b/src/library/scala/collection/immutable/HashMap.scala
@@ -75,9 +75,10 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] {
   protected def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1)): HashMap[A, B1] =
     new HashMap.HashMap1(key, hash, value, kv)
 
+  protected def removed0(key: A, hash: Int, level: Int): HashMap[A, B] = this
 
-  protected def removed0(key: A, hash: Int, level: Int): HashMap[A, B] = this
+  protected def writeReplace(): AnyRef = new HashMap.SerializationProxy(this)
 }
@@ -113,14 +114,7 @@ object HashMap extends ImmutableMapFactory[HashMap] {
         m.updated0(this.key, this.hash, level, this.value, this.kv).updated0(key, hash, level, value, kv)
       } else {
         // 32-bit hash collision (rare, but not impossible)
-        // wrap this in a HashTrieMap if called with level == 0 (otherwise serialization won't work)
-        if (level == 0) {
-          val elems = new Array[HashMap[A,B1]](1)
-          elems(0) = new HashMapCollision1(hash, ListMap.empty.updated(this.key,this.value).updated(key,value))
-          new HashTrieMap[A,B1](1 << ((hash >>> level) & 0x1f), elems, 2)
-        } else {
-          new HashMapCollision1(hash, ListMap.empty.updated(this.key,this.value).updated(key,value))
-        }
+        new HashMapCollision1(hash, ListMap.empty.updated(this.key,this.value).updated(key,value))
       }
     }
@@ -130,18 +124,6 @@ object HashMap extends ImmutableMapFactory[HashMap] {
     override def iterator: Iterator[(A,B)] = Iterator(ensurePair)
     override def foreach[U](f: ((A, B)) => U): Unit = f(ensurePair)
     private[HashMap] def ensurePair: (A,B) = if (kv ne null) kv else { kv = (key, value); kv }
-
-    private def writeObject(out: java.io.ObjectOutputStream) {
-      out.writeObject(key)
-      out.writeObject(value)
-    }
-
-    private def readObject(in: java.io.ObjectInputStream) {
-      key = in.readObject().asInstanceOf[A]
-      value = in.readObject().asInstanceOf[B]
-      hash = computeHash(key)
-    }
-
   }
 
   private class HashMapCollision1[A,+B](private[HashMap] var hash: Int, var kvs: ListMap[A,B @uncheckedVariance]) extends HashMap[A,B] {
@@ -171,22 +153,6 @@ object HashMap extends ImmutableMapFactory[HashMap] {
     override def iterator: Iterator[(A,B)] = kvs.iterator
     override def foreach[U](f: ((A, B)) => U): Unit = kvs.foreach(f)
-
-    private def writeObject(out: java.io.ObjectOutputStream) {
-      // this cannot work - reading things in might produce different
-      // hash codes and remove the collision. however this is never called
-      // because no references to this class are ever handed out to client code
-      // and HashTrieMap serialization takes care of the situation
-      error("cannot serialize an immutable.HashMap where all items have the same 32-bit hash code")
-      //out.writeObject(kvs)
-    }
-
-    private def readObject(in: java.io.ObjectInputStream) {
-      error("cannot deserialize an immutable.HashMap where all items have the same 32-bit hash code")
-      //kvs = in.readObject().asInstanceOf[ListMap[A,B]]
-      //hash = computeHash(kvs.)
-    }
-
   }
@@ -251,19 +217,27 @@ object HashMap extends ImmutableMapFactory[HashMap] {
       val index = (hash >>> level) & 0x1f
      val mask = (1 << index)
      val offset = Integer.bitCount(bitmap & (mask-1))
-      if (((bitmap >>> index) & 1) == 1) {
-        val elemsNew = new Array[HashMap[A,B]](elems.length)
-        Array.copy(elems, 0, elemsNew, 0, elems.length)
+      if ((bitmap & mask) != 0) {
        val sub = elems(offset)
        // TODO: might be worth checking if sub is HashTrieMap (-> monomorphic call site)
        val subNew = sub.removed0(key, hash, level + 5)
-        elemsNew(offset) = subNew
-        // TODO: handle shrinking
-        val sizeNew = size + (subNew.size - sub.size)
-        if (sizeNew > 0)
-          new HashTrieMap(bitmap, elemsNew, size + (subNew.size - sub.size))
-        else
-          HashMap.empty[A,B]
+        if (subNew.isEmpty) {
+          val bitmapNew = bitmap ^ mask
+          if (bitmapNew != 0) {
+            val elemsNew = new Array[HashMap[A,B]](elems.length - 1)
+            Array.copy(elems, 0, elemsNew, 0, offset)
+            Array.copy(elems, offset + 1, elemsNew, offset, elems.length - offset - 1)
+            val sizeNew = size - sub.size
+            new HashTrieMap(bitmapNew, elemsNew, sizeNew)
+          } else
+            HashMap.empty[A,B]
+        } else {
+          val elemsNew = new Array[HashMap[A,B]](elems.length)
+          Array.copy(elems, 0, elemsNew, 0, elems.length)
+          elemsNew(offset) = subNew
+          val sizeNew = size + (subNew.size - sub.size)
+          new HashTrieMap(bitmap, elemsNew, sizeNew)
+        }
      } else {
        this
      }
    }
@@ -372,31 +346,29 @@ time { mNew.iterator.foreach( p => ()) }
     }
   }
+  }
+  @serializable @SerialVersionUID(2L) private class SerializationProxy[A,B](@transient private var orig: HashMap[A, B]) {
     private def writeObject(out: java.io.ObjectOutputStream) {
-      // no out.defaultWriteObject()
-      out.writeInt(size)
-      foreach { p =>
-        out.writeObject(p._1)
-        out.writeObject(p._2)
+      val s = orig.size
+      out.writeInt(s)
+      for ((k,v) <- orig) {
+        out.writeObject(k)
+        out.writeObject(v)
       }
     }
 
     private def readObject(in: java.io.ObjectInputStream) {
-      val size = in.readInt
-      var index = 0
-      var m = HashMap.empty[A,B]
-      while (index < size) {
-        // TODO: optimize (use unsafe mutable update)
-        m = m + ((in.readObject.asInstanceOf[A], in.readObject.asInstanceOf[B]))
-        index += 1
+      orig = empty
+      val s = in.readInt()
+      for (i <- 0 until s) {
+        val key = in.readObject().asInstanceOf[A]
+        val value = in.readObject().asInstanceOf[B]
+        orig = orig.updated(key, value)
       }
-      var tm = m.asInstanceOf[HashTrieMap[A,B]]
-      bitmap = tm.bitmap
-      elems = tm.elems
-      size0 = tm.size0
     }
+    private def readResolve(): AnyRef = orig
   }
 }

diff --git a/src/library/scala/collection/immutable/HashSet.scala b/src/library/scala/collection/immutable/HashSet.scala
index 4779702ea7..942c3f1814 100644
--- a/src/library/scala/collection/immutable/HashSet.scala
+++ b/src/library/scala/collection/immutable/HashSet.scala
@@ -72,20 +72,11 @@ class HashSet[A] extends Set[A]
   protected def updated0(key: A, hash: Int, level: Int): HashSet[A] =
     new HashSet.HashSet1(key, hash)
-
-
   protected def removed0(key: A, hash: Int, level: Int): HashSet[A] = this
+  protected def writeReplace(): AnyRef =
+    new HashSet.SerializationProxy(this)
 }
 
-
-/*
-object HashSet extends SetFactory[HashSet] {
-  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, HashSet[A]] = setCanBuildFrom[A]
-  override def empty[A]: HashSet[A] = new HashSet
-}
-*/
-
-
 /** $factoryInfo
  * @define Coll immutable.HashSet
  * @define coll immutable hash set
@@ -122,14 +113,7 @@ object HashSet extends SetFactory[HashSet] {
         m.updated0(this.key, this.hash, level).updated0(key, hash, level)
       } else {
         // 32-bit hash collision (rare, but not impossible)
-        // wrap this in a HashTrieSet if called with level == 0 (otherwise serialization won't work)
-        if (level == 0) {
-          val elems = new Array[HashSet[A]](1)
-          elems(0) = new HashSetCollision1(hash, ListSet.empty + this.key + key)
-          new HashTrieSet[A](1 << ((hash >>> level) & 0x1f), elems, 2)
-        } else {
-          new HashSetCollision1(hash, ListSet.empty + this.key + key)
-        }
+        new HashSetCollision1(hash, ListSet.empty + this.key + key)
       }
     }
@@ -138,16 +122,6 @@ object HashSet extends SetFactory[HashSet] {
     override def iterator: Iterator[A] = Iterator(key)
     override def foreach[U](f: A => U): Unit = f(key)
-
-    private def writeObject(out: java.io.ObjectOutputStream) {
-      out.writeObject(key)
-    }
-
-    private def readObject(in: java.io.ObjectInputStream) {
-      key = in.readObject().asInstanceOf[A]
-      hash = computeHash(key)
-    }
-
   }
 
   private class HashSetCollision1[A](private[HashSet] var hash: Int, var ks: ListSet[A]) extends HashSet[A] {
@@ -240,24 +214,33 @@ object HashSet extends SetFactory[HashSet] {
      val index = (hash >>> level) & 0x1f
      val mask = (1 << index)
      val offset = Integer.bitCount(bitmap & (mask-1))
-      if (((bitmap >>> index) & 1) == 1) {
-        val elemsNew = new Array[HashSet[A]](elems.length)
-        Array.copy(elems, 0, elemsNew, 0, elems.length)
+      if ((bitmap & mask) != 0) {
        val sub = elems(offset)
-        // TODO: might be worth checking if sub is HashTrieSet (-> monomorphic call site)
+        // TODO: might be worth checking if sub is HashTrieMap (-> monomorphic call site)
        val subNew = sub.removed0(key, hash, level + 5)
-        elemsNew(offset) = subNew
-        // TODO: handle shrinking
-        val sizeNew = size + (subNew.size - sub.size)
-        if (sizeNew > 0)
-          new HashTrieSet(bitmap, elemsNew, size + (subNew.size - sub.size))
-        else
-          HashSet.empty[A]
+        if (subNew.isEmpty) {
+          val bitmapNew = bitmap ^ mask
+          if (bitmapNew != 0) {
+            val elemsNew = new Array[HashSet[A]](elems.length - 1)
+            Array.copy(elems, 0, elemsNew, 0, offset)
+            Array.copy(elems, offset + 1, elemsNew, offset, elems.length - offset - 1)
+            val sizeNew = size - sub.size
+            new HashTrieSet(bitmapNew, elemsNew, sizeNew)
+          } else
+            HashSet.empty[A]
+        } else {
+          val elemsNew = new Array[HashSet[A]](elems.length)
+          Array.copy(elems, 0, elemsNew, 0, elems.length)
+          elemsNew(offset) = subNew
+          val sizeNew = size + (subNew.size - sub.size)
+          new HashTrieSet(bitmap, elemsNew, sizeNew)
+        }
      } else {
        this
      }
    }
+
    override def iterator = new Iterator[A] {
      private[this] var depth = 0
      private[this] var arrayStack = new Array[Array[HashSet[A]]](6)
@@ -341,31 +324,27 @@ time { mNew.iterator.foreach( p => ()) }
       i += 1
     }
   }
+  }
-
+  @serializable @SerialVersionUID(2L) private class SerializationProxy[A,B](@transient private var orig: HashSet[A]) {
     private def writeObject(out: java.io.ObjectOutputStream) {
-      // no out.defaultWriteObject()
-      out.writeInt(size)
-      foreach { e =>
+      val s = orig.size
+      out.writeInt(s)
+      for (e <- orig) {
         out.writeObject(e)
       }
     }
 
     private def readObject(in: java.io.ObjectInputStream) {
-      val size = in.readInt
-      var index = 0
-      var m = HashSet.empty[A]
-      while (index < size) {
-        // TODO: optimize (use unsafe mutable update)
-        m = m + in.readObject.asInstanceOf[A]
-        index += 1
+      orig = empty
+      val s = in.readInt()
+      for (i <- 0 until s) {
+        val e = in.readObject().asInstanceOf[A]
+        orig = orig + e
       }
-      var tm = m.asInstanceOf[HashTrieSet[A]]
-      bitmap = tm.bitmap
-      elems = tm.elems
-      size0 = tm.size0
     }
+    private def readResolve(): AnyRef = orig
   }
 }

diff --git a/test/files/jvm/t1600.scala b/test/files/jvm/t1600.scala
index 1cdcee8547..79391b7e76 100644
--- a/test/files/jvm/t1600.scala
+++ b/test/files/jvm/t1600.scala
@@ -12,11 +12,11 @@ object Test {
     def elements = entries.map(_._1)
 
     val maps = Seq[Map[Foo, Int]](new mutable.HashMap, new mutable.LinkedHashMap,
-        new immutable.HashMap).map(_ ++ entries)
+        immutable.HashMap.empty).map(_ ++ entries)
     test[Map[Foo, Int]](maps, entries.size, assertMap _)
 
     val sets = Seq[Set[Foo]](new mutable.HashSet, new mutable.LinkedHashSet,
-        new immutable.HashSet).map(_ ++ elements)
+        immutable.HashSet.empty).map(_ ++ elements)
     test[Set[Foo]](sets, entries.size, assertSet _)
   }
 }
@@ -31,7 +31,7 @@ object Test {
       assertFunction(deserializedCollection, expectedSize)
 
       assert(deserializedCollection.getClass == collection.getClass,
-        "collection class should remain the same after deserialization")
+        "collection class should remain the same after deserialization ("+deserializedCollection.getClass+" != "+collection.getClass+")")
       Foo.hashCodeModifier = 0
     }
   }

diff --git a/test/files/run/t3241.check b/test/files/run/t3241.check
new file mode 100644
index 0000000000..348ebd9491
--- /dev/null
+++ b/test/files/run/t3241.check
@@ -0,0 +1 @@
+done
\ No newline at end of file

diff --git a/test/files/run/t3241.scala b/test/files/run/t3241.scala
new file mode 100644
index 0000000000..40097a046f
--- /dev/null
+++ b/test/files/run/t3241.scala
@@ -0,0 +1,23 @@
+object Test {
+
+  def main(args : Array[String]) : Unit = {
+    recurse(Map(1->1, 2->2, 3->3, 4->4, 5->5, 6->6, 7->7))
+    recurse(Set(1,2,3,4,5,6,7))
+    println("done")
+  }
+
+  def recurse(map: collection.immutable.Map[Int, Int]): Unit = {
+    if (!map.isEmpty) {
+      val x = map.keys.head
+      recurse(map - x)
+    }
+  }
+
+  def recurse(set: collection.immutable.Set[Int]): Unit = {
+    if (!set.isEmpty) {
+      val x = set.toStream.head
+      recurse(set - x)
+    }
+  }
+
+}
--
cgit v1.2.3
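
The heart of the patch is the serialization-proxy pattern: a hash trie never writes its internal nodes to an ObjectOutputStream; its writeReplace() substitutes a small proxy that streams out only the element count and the elements, and the proxy's readResolve() hands back a collection rebuilt through the public API, so hash codes are recomputed in the receiving JVM. Below is a minimal, self-contained sketch of that pattern under made-up names (IntBag, Repr and ProxyDemo are invented for the sketch and are not part of the patch):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical immutable container standing in for HashMap/HashSet.
// Only the proxy (Repr) is ever written to the stream.
class IntBag private (val elems: List[Int]) extends Serializable {
  def +(x: Int): IntBag = new IntBag(x :: elems)
  // Called by ObjectOutputStream before writing: substitute the proxy.
  protected def writeReplace(): AnyRef = new IntBag.Repr(this)
}

object IntBag {
  def empty: IntBag = new IntBag(Nil)

  @SerialVersionUID(1L)
  private class Repr(@transient private var orig: IntBag) extends Serializable {
    private def writeObject(out: ObjectOutputStream): Unit = {
      out.writeInt(orig.elems.size)                    // write the elements, not tree nodes
      orig.elems.foreach(e => out.writeInt(e))
    }
    private def readObject(in: ObjectInputStream): Unit = {
      orig = empty
      val n = in.readInt()
      for (i <- 0 until n) orig = orig + in.readInt()  // rebuild through the public API
    }
    // Called by ObjectInputStream after reading: hand back the real collection.
    private def readResolve(): AnyRef = orig
  }
}

object ProxyDemo {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(IntBag.empty + 1 + 2 + 3)
    oos.flush()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val back = in.readObject().asInstanceOf[IntBag]
    println(back.elems.sorted)                         // List(1, 2, 3)
  }
}

Because the proxy rebuilds the structure element by element, the layout of the deserialized trie always matches the hash codes of the JVM doing the reading, which is exactly the property the removed readObject methods in HashMap1 and HashMapCollision1 could not guarantee.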
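
Both removed0 rewrites in the patch rely on the same bitmapped-trie arithmetic: (hash >>> level) & 0x1f selects a 5-bit child index, 1 << index is that child's bit in the bitmap, and Integer.bitCount(bitmap & (mask-1)) converts the bit position into an offset into the compressed child array; the new code additionally clears the bit and compacts the array when a child becomes empty instead of leaving a stale slot. A small sketch of just that arithmetic on a toy node (BitmapNode and BitmapDemo are hypothetical names, not the real HashTrieMap/HashTrieSet):

// Toy bitmap-indexed node holding plain Strings; illustrates only the
// index/mask/offset arithmetic and the shrink-on-remove step.
final case class BitmapNode(bitmap: Int, elems: Array[String]) {

  private def slot(hash: Int, level: Int): (Int, Int) = {
    val index = (hash >>> level) & 0x1f            // 5 bits of the hash per level
    val mask  = 1 << index                         // bit for that child
    (mask, Integer.bitCount(bitmap & (mask - 1)))  // children set below it = array offset
  }

  def get(hash: Int, level: Int): Option[String] = {
    val (mask, offset) = slot(hash, level)
    if ((bitmap & mask) != 0) Some(elems(offset)) else None
  }

  // Remove the child addressed by `hash`: clear its bit and compact the array,
  // mirroring the "subNew.isEmpty" branch added by the patch.
  def remove(hash: Int, level: Int): BitmapNode = {
    val (mask, offset) = slot(hash, level)
    if ((bitmap & mask) == 0) this
    else {
      val elemsNew = new Array[String](elems.length - 1)
      Array.copy(elems, 0, elemsNew, 0, offset)
      Array.copy(elems, offset + 1, elemsNew, offset, elems.length - offset - 1)
      BitmapNode(bitmap ^ mask, elemsNew)
    }
  }
}

object BitmapDemo {
  def main(args: Array[String]): Unit = {
    val node = BitmapNode((1 << 3) | (1 << 7), Array("a", "b")) // children at indices 3 and 7
    println(node.get(7, 0))                  // Some(b)
    println(node.remove(3, 0).elems.toList)  // List(b)
  }
}

Running BitmapDemo prints Some(b) followed by List(b).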