From 2e1aefd1188602c870229d4b6e00d1acfc83036d Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Tue, 23 Mar 2010 15:11:05 +0000 Subject: Fixes #3186. Closes #2214. --- src/library/scala/Enumeration.scala | 41 +++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/library/scala/Enumeration.scala b/src/library/scala/Enumeration.scala index 8a708a4615..205cf1cf84 100644 --- a/src/library/scala/Enumeration.scala +++ b/src/library/scala/Enumeration.scala @@ -16,6 +16,15 @@ import scala.collection.mutable.{Builder, AddingBuilder, Map, HashMap} import scala.collection.immutable.{Set, BitSet} import scala.collection.generic.CanBuildFrom +private object Enumeration { + + /* This map is used to cache enumeration instances for + resolving enumeration _values_ to equal objects (by-reference) + when values are deserialized. */ + private val emap: Map[Class[_], Enumeration] = new HashMap + +} + /**

* Defines a finite set of values specific to the enumeration. Typically * these values enumerate all possible forms something can take and provide a @@ -61,6 +70,21 @@ abstract class Enumeration(initial: Int, names: String*) { def this() = this(0, null) def this(names: String*) = this(0, names: _*) + Enumeration.emap.get(getClass) match { + case None => + Enumeration.emap += (getClass -> this) + case Some(_) => + /* do nothing */ + } + + private def readResolve(): AnyRef = Enumeration.emap.get(getClass) match { + case None => + Enumeration.emap += (getClass -> this) + this + case Some(existing) => + existing + } + /** The name of this enumeration. */ override def toString = { @@ -192,13 +216,13 @@ abstract class Enumeration(initial: Int, names: String*) { /** The type of the enumerated values. */ @serializable @SerialVersionUID(7091335633555234129L) - abstract class Value extends Ordered[Enumeration#Value] { + abstract class Value extends Ordered[Value] { /** the id and bit location of this enumeration value */ def id: Int - override def compare(that: Enumeration#Value): Int = this.id - that.id + override def compare(that: Value): Int = this.id - that.id override def equals(other: Any): Boolean = other match { - case that: Enumeration#Value => compare(that) == 0 + case that: Value => compare(that) == 0 case _ => false } override def hashCode: Int = id.hashCode @@ -211,7 +235,7 @@ abstract class Enumeration(initial: Int, names: String*) { if (id >= 32) throw new IllegalArgumentException 1 << id } - /** this enumeration value as an Long bit mask. + /** this enumeration value as a Long bit mask. * @throws IllegalArgumentException if id is greater than 63 */ @deprecated("mask64 will be removed") @@ -243,9 +267,14 @@ abstract class Enumeration(initial: Int, names: String*) { override def toString() = if (name eq null) Enumeration.this.nameOf(i) else name - private def readResolve(): AnyRef = - if (vmap ne null) vmap(i) + private def readResolve(): AnyRef = { + val enum = Enumeration.emap.get(Enumeration.this.getClass) match { + case None => Enumeration.this + case Some(existing) => existing + } + if (enum.vmap ne null) enum.vmap(i) else this + } } /** A class for sets of values -- cgit v1.2.3