diff options
-rw-r--r-- | src/compiler/scala/tools/nsc/backend/icode/GenICode.scala | 34 | ||||
-rw-r--r-- | src/library/scala/math/BigDecimal.scala | 12 | ||||
-rw-r--r-- | src/library/scala/math/BigInt.scala | 15 | ||||
-rw-r--r-- | src/library/scala/math/ScalaNumber.java | 1 | ||||
-rw-r--r-- | src/library/scala/math/ScalaNumericConversions.scala | 27 | ||||
-rw-r--r-- | src/library/scala/runtime/BoxesRunTime.java | 25 | ||||
-rw-r--r-- | test/files/run/equality.scala | 36 |
7 files changed, 121 insertions, 29 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala b/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala index e1ddd260cb..f0c44bb227 100644 --- a/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala +++ b/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala @@ -1410,24 +1410,32 @@ abstract class GenICode extends SubComponent { * When it is statically known that both sides are equal and subtypes of Number of Character, * not using the rich equality is possible (their own equals method will do ok.)*/ def mustUseAnyComparator: Boolean = { + import definitions._ + + /** The various ways a boxed primitive might materialize at runtime. */ + def isJavaBoxed(sym: Symbol) = + (sym == ObjectClass) || + (sym == SerializableClass) || + (sym == ComparableClass) || + (sym isNonBottomSubClass BoxedNumberClass) || + (sym isNonBottomSubClass BoxedCharacterClass) + def isBoxed(sym: Symbol): Boolean = - ((sym isNonBottomSubClass definitions.BoxedNumberClass) || - (!forMSIL && (sym isNonBottomSubClass definitions.BoxedCharacterClass))) - - val lsym = l.tpe.typeSymbol - val rsym = r.tpe.typeSymbol - (lsym == ObjectClass) || - (rsym == ObjectClass) || - (lsym != rsym) && (isBoxed(lsym) || isBoxed(rsym)) + if (forMSIL) (sym isNonBottomSubClass BoxedNumberClass) + else isJavaBoxed(sym) + + isBoxed(l.tpe.typeSymbol) && isBoxed(r.tpe.typeSymbol) } if (mustUseAnyComparator) { - var equalsMethod = BoxesRunTime_equals // when -optimise is on we call the @inline-version of equals, found in ScalaRunTime - if (settings.XO.value) { - equalsMethod = definitions.getMember(definitions.ScalaRunTimeModule, nme.inlinedEquals) - ctx.bb.emit(LOAD_MODULE(definitions.ScalaRunTimeModule)) - } + val equalsMethod = + if (!settings.XO.value) BoxesRunTime_equals + else { + ctx.bb.emit(LOAD_MODULE(definitions.ScalaRunTimeModule)) + definitions.getMember(definitions.ScalaRunTimeModule, nme.inlinedEquals) + } + val ctx1 = genLoad(l, ctx, ANY_REF_CLASS) val ctx2 = genLoad(r, ctx1, ANY_REF_CLASS) ctx2.bb.emit(CALL_METHOD(equalsMethod, if (settings.XO.value) Dynamic else Static(false))) diff --git a/src/library/scala/math/BigDecimal.scala b/src/library/scala/math/BigDecimal.scala index c379b83abf..d5a57b2c62 100644 --- a/src/library/scala/math/BigDecimal.scala +++ b/src/library/scala/math/BigDecimal.scala @@ -165,17 +165,21 @@ extends ScalaNumber with ScalaNumericConversions * which deems 2 == 2.00, whereas in java these are unequal * with unequal hashCodes. */ - override def hashCode(): Int = doubleValue.hashCode() + override def hashCode(): Int = + if (isWhole) unifiedPrimitiveHashcode + else doubleValue.hashCode() /** Compares this BigDecimal with the specified value for equality. * Will only claim equality with scala.BigDecimal and java.math.BigDecimal. */ override def equals (that: Any): Boolean = that match { - case that: BigDecimal => this equals that - case that: BigDec => this equals BigDecimal(that) - case _ => false + case that: BigDecimal => this equals that + case that: BigInt => this.toBigIntExact exists (that equals _) + case x => unifiedPrimitiveEquals(x) } + override protected def isWhole = (this remainder 1) == BigDecimal(0) + /** Compares this BigDecimal with the specified BigDecimal for equality. */ def equals (that: BigDecimal): Boolean = compare(that) == 0 diff --git a/src/library/scala/math/BigInt.scala b/src/library/scala/math/BigInt.scala index f1e89f4f53..5e4bb569b5 100644 --- a/src/library/scala/math/BigInt.scala +++ b/src/library/scala/math/BigInt.scala @@ -24,6 +24,9 @@ object BigInt { private val maxCached = 1024 private val cache = new Array[BigInt](maxCached - minCached + 1) + val MinLong = BigInt(Long.MinValue) + val MaxLong = BigInt(Long.MaxValue) + /** Constructs a <code>BigInt</code> whose value is equal to that of the * specified integer value. * @@ -112,16 +115,20 @@ object BigInt { class BigInt(val bigInteger: BigInteger) extends ScalaNumber with ScalaNumericConversions { /** Returns the hash code for this BigInt. */ - override def hashCode(): Int = this.bigInteger.hashCode() + override def hashCode(): Int = + if (this >= BigInt.MinLong && this <= BigInt.MaxLong) unifiedPrimitiveHashcode + else bigInteger.hashCode /** Compares this BigInt with the specified value for equality. */ override def equals(that: Any): Boolean = that match { - case that: BigInt => this equals that - case that: BigInteger => this equals new BigInt(that) - case _ => false + case that: BigInt => this equals that + case that: BigDecimal => that.toBigIntExact exists (this equals _) + case x => unifiedPrimitiveEquals(x) } + override protected def isWhole = true + /** Compares this BigInt with the specified BigInt for equality. */ def equals (that: BigInt): Boolean = compare(that) == 0 diff --git a/src/library/scala/math/ScalaNumber.java b/src/library/scala/math/ScalaNumber.java index bb54a5d9c0..4ade1dee14 100644 --- a/src/library/scala/math/ScalaNumber.java +++ b/src/library/scala/math/ScalaNumber.java @@ -17,4 +17,5 @@ package scala.math; * @since 2.8 */ public abstract class ScalaNumber extends java.lang.Number { + protected abstract boolean isWhole(); } diff --git a/src/library/scala/math/ScalaNumericConversions.scala b/src/library/scala/math/ScalaNumericConversions.scala index 53465c7438..cc30d5d756 100644 --- a/src/library/scala/math/ScalaNumericConversions.scala +++ b/src/library/scala/math/ScalaNumericConversions.scala @@ -8,10 +8,12 @@ package scala.math +import java.{ lang => jl } + /** Conversions which present a consistent conversion interface * across all the numeric types. */ -trait ScalaNumericConversions extends java.lang.Number { +trait ScalaNumericConversions extends ScalaNumber { def toChar = intValue.toChar def toByte = byteValue def toShort = shortValue @@ -19,4 +21,27 @@ trait ScalaNumericConversions extends java.lang.Number { def toLong = longValue def toFloat = floatValue def toDouble = doubleValue + + def isValidByte = isWhole && (toByte == toInt) + def isValidShort = isWhole && (toShort == toInt) + def isValidInt = isWhole && (toInt == toLong) + def isValidChar = isWhole && (toInt >= Char.MinValue && toInt <= Char.MaxValue) + + protected def unifiedPrimitiveHashcode() = { + val lv = toLong + if (lv >= Int.MinValue && lv <= Int.MaxValue) lv.toInt + else lv.hashCode + } + + protected def unifiedPrimitiveEquals(x: Any) = x match { + case x: Char => isValidChar && (toInt == x.toInt) + case x: Byte => isValidByte && (toByte == x) + case x: Short => isValidShort && (toShort == x) + case x: Int => isValidInt && (toInt == x) + case x: Long => toLong == x // XXX + case x: Float => toFloat == x // XXX + case x: Double => toDouble == x // XXX + case x: Number => this equals x + case _ => false + } } diff --git a/src/library/scala/runtime/BoxesRunTime.java b/src/library/scala/runtime/BoxesRunTime.java index 39ea9abcdd..869eb375ac 100644 --- a/src/library/scala/runtime/BoxesRunTime.java +++ b/src/library/scala/runtime/BoxesRunTime.java @@ -140,29 +140,37 @@ public class BoxesRunTime /* COMPARISON ... COMPARISON ... COMPARISON ... COMPARISON ... COMPARISON ... COMPARISON */ - // That's the method we should use from now on. + /** Since all applicable logic has to be present in the equals method of a ScalaNumber + * in any case, we dispatch to it as soon as we spot one on either side. + */ public static boolean equals(Object x, Object y) { if (x instanceof Number) { + if (x instanceof ScalaNumber) + return x.equals(y); + Number xn = (Number)x; if (y instanceof Number) { - Number yn = (Number)y; - if ((y instanceof ScalaNumber) && !(x instanceof ScalaNumber)) { + if (y instanceof ScalaNumber) return y.equals(x); - } + + Number yn = (Number)y; if ((xn instanceof Double) || (yn instanceof Double)) return xn.doubleValue() == yn.doubleValue(); if ((xn instanceof Float) || (yn instanceof Float)) return xn.floatValue() == yn.floatValue(); if ((xn instanceof Long) || (yn instanceof Long)) return xn.longValue() == yn.longValue(); - return xn.intValue() == yn.intValue(); + if (typeCode(x) <= INT && typeCode(y) <= INT) + return xn.intValue() == yn.intValue(); + + return x.equals(y); } if (y instanceof Character) return equalsNumChar(xn, (Character)y); } else if (x instanceof Character) { Character xc = (Character)x; if (y instanceof Character) - return (xc.charValue() == ((Character)y).charValue()); + return xc.equals(y); if (y instanceof Number) return equalsNumChar((Number)y, xc); } else if (x == null) { @@ -181,7 +189,10 @@ public class BoxesRunTime return x.longValue() == ch; if (x instanceof ScalaNumber) return x.equals(y); - return x.intValue() == ch; + if (typeCode(x) <= INT) + return x.intValue() == ch; + + return x.equals(y); } /** Hashcode algorithm is driven by the requirements imposed diff --git a/test/files/run/equality.scala b/test/files/run/equality.scala new file mode 100644 index 0000000000..5b9ad207da --- /dev/null +++ b/test/files/run/equality.scala @@ -0,0 +1,36 @@ +// a quickly assembled test of equality. Needs work. +object Test +{ + def makeFromInt(x: Int) = List( + x.toByte, x.toShort, x.toInt, x.toLong, x.toFloat, x.toDouble, BigInt(x), BigDecimal(x) + ) ::: ( + if (x < 0) Nil else List(x.toChar) + ) + def makeFromDouble(x: Double) = List( + x.toShort, x.toInt, x.toLong, x.toFloat, x.toDouble, BigInt(x.toInt), BigDecimal(x) + ) + + def main(args: Array[String]): Unit = { + var xs = makeFromInt(5) + for (x <- xs ; y <- xs) { + assert(x == y, x + " == " + y) + assert(hash(x) == hash(y), "hash(%s) == hash(%s)".format(x, y)) + } + + xs = makeFromInt(-5) + for (x <- xs ; y <- xs) { + assert(x == y, x + " == " + y) + assert(hash(x) == hash(y), "hash(%s) == hash(%s)".format(x, y)) + } + + xs = makeFromDouble(500.0) + for (x <- xs ; y <- xs) { + assert(x == y, x + " == " + y) + assert(hash(x) == hash(y), "hash(%s) == hash(%s)".format(x, y)) + } + + // negatives + val bigLong = new java.util.concurrent.atomic.AtomicLong(Long.MaxValue) + assert(-1 != bigLong && bigLong != -1) // bigLong.intValue() == -1 + } +} |