summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/library/scala/math/BigDecimal.scala6
-rw-r--r--src/library/scala/math/BigInt.scala2
-rw-r--r--src/library/scala/math/ScalaNumericConversions.scala21
-rw-r--r--src/partest/scala/tools/partest/category/Runner.scala2
-rw-r--r--test/files/run/numbereq.scala41
5 files changed, 64 insertions, 8 deletions
diff --git a/src/library/scala/math/BigDecimal.scala b/src/library/scala/math/BigDecimal.scala
index bb6965fcdc..c4cccd1f52 100644
--- a/src/library/scala/math/BigDecimal.scala
+++ b/src/library/scala/math/BigDecimal.scala
@@ -30,6 +30,9 @@ object BigDecimal
private val minCached = -512
private val maxCached = 512
+ val MinLong = BigDecimal(Long.MinValue)
+ val MaxLong = BigDecimal(Long.MaxValue)
+
/** Cache ony for defaultMathContext using BigDecimals in a small range. */
private lazy val cache = new Array[BigDecimal](maxCached - minCached + 1)
@@ -173,12 +176,11 @@ extends ScalaNumber with ScalaNumericConversions
else doubleValue.hashCode()
/** Compares this BigDecimal with the specified value for equality.
- * Will only claim equality with scala.BigDecimal and java.math.BigDecimal.
*/
override def equals (that: Any): Boolean = that match {
case that: BigDecimal => this equals that
case that: BigInt => this.toBigIntExact exists (that equals _)
- case x => unifiedPrimitiveEquals(x)
+ case x => (this <= BigDecimal.MaxLong && this >= BigDecimal.MinLong) && unifiedPrimitiveEquals(x)
}
protected[math] def isWhole = (this remainder 1) == BigDecimal(0)
diff --git a/src/library/scala/math/BigInt.scala b/src/library/scala/math/BigInt.scala
index 5267ad8b95..f0988f2934 100644
--- a/src/library/scala/math/BigInt.scala
+++ b/src/library/scala/math/BigInt.scala
@@ -124,7 +124,7 @@ class BigInt(val bigInteger: BigInteger) extends ScalaNumber with ScalaNumericCo
override def equals(that: Any): Boolean = that match {
case that: BigInt => this equals that
case that: BigDecimal => that.toBigIntExact exists (this equals _)
- case x => unifiedPrimitiveEquals(x)
+ case x => (this <= BigInt.MaxLong && this >= BigInt.MinLong) && unifiedPrimitiveEquals(x)
}
protected[math] def isWhole = true
diff --git a/src/library/scala/math/ScalaNumericConversions.scala b/src/library/scala/math/ScalaNumericConversions.scala
index d983595c17..2e754c2584 100644
--- a/src/library/scala/math/ScalaNumericConversions.scala
+++ b/src/library/scala/math/ScalaNumericConversions.scala
@@ -33,15 +33,28 @@ trait ScalaNumericConversions extends ScalaNumber {
else lv.hashCode
}
+ /** Should only be called after all known non-primitive
+ * types have been excluded. This method won't dispatch
+ * anywhere else after checking against the primitives
+ * to avoid infinite recursion between equals and this on
+ * unknown "Number" variants.
+ *
+ * Additionally, this should only be called if the numeric
+ * type is happy to be converted to Long, Float, and Double.
+ * If for instance a BigInt much larger than the Long range is
+ * sent here, it will claim equality with whatever Long is left
+ * in its lower 64 bits. Or a BigDecimal with more precision
+ * than Double can hold: same thing. There's no way given the
+ * interface available here to prevent this error.
+ */
protected def unifiedPrimitiveEquals(x: Any) = x match {
case x: Char => isValidChar && (toInt == x.toInt)
case x: Byte => isValidByte && (toByte == x)
case x: Short => isValidShort && (toShort == x)
case x: Int => isValidInt && (toInt == x)
- case x: Long => toLong == x // XXX
- case x: Float => toFloat == x // XXX
- case x: Double => toDouble == x // XXX
- case x: Number => this equals x
+ case x: Long => toLong == x
+ case x: Float => toFloat == x
+ case x: Double => toDouble == x
case _ => false
}
}
diff --git a/src/partest/scala/tools/partest/category/Runner.scala b/src/partest/scala/tools/partest/category/Runner.scala
index a7713d7dbe..10bf5794a9 100644
--- a/src/partest/scala/tools/partest/category/Runner.scala
+++ b/src/partest/scala/tools/partest/category/Runner.scala
@@ -75,7 +75,7 @@ trait Runner {
def allResults() =
for ((prop, res) <- result) yield {
- ScalacheckTest.this.trace("scalacheck result for %s: %s".format(prop, res))
+ ScalacheckTest.this.trace("%s: %s".format(prop, res))
res.asInstanceOf[ScalacheckResult].passed
}
diff --git a/test/files/run/numbereq.scala b/test/files/run/numbereq.scala
new file mode 100644
index 0000000000..52f32cc52a
--- /dev/null
+++ b/test/files/run/numbereq.scala
@@ -0,0 +1,41 @@
+object Test {
+ def mkNumbers(x: Int): List[AnyRef] = {
+ val base = List(
+ BigDecimal(x),
+ BigInt(x),
+ new java.lang.Double(x.toDouble),
+ new java.lang.Float(x.toFloat),
+ new java.lang.Long(x.toLong),
+ new java.lang.Integer(x)
+ )
+ val extras = List(
+ if (x >= Short.MinValue && x <= Short.MaxValue) List(new java.lang.Short(x.toShort)) else Nil,
+ if (x >= Byte.MinValue && x <= Byte.MaxValue) List(new java.lang.Byte(x.toByte)) else Nil,
+ if (x >= Char.MinValue && x <= Char.MaxValue) List(new java.lang.Character(x.toChar)) else Nil
+ ).flatten
+
+ base ::: extras
+ }
+
+
+ def main(args: Array[String]): Unit = {
+ val ints = (0 to 15).toList map (Short.MinValue >> _)
+ val ints2 = ints map (x => -x)
+ val ints3 = ints map (_ + 1)
+ val ints4 = ints2 map (_ - 1)
+
+ val setneg1 = ints map mkNumbers
+ val setneg2 = ints3 map mkNumbers
+ val setpos1 = ints2 map mkNumbers
+ val setpos2 = ints4 map mkNumbers
+ val zero = mkNumbers(0)
+
+ val sets = setneg1 ++ setneg2 ++ List(zero) ++ setpos1 ++ setpos2
+
+ for (set <- sets ; x <- set ; y <- set) {
+ println("'%s' == '%s' (%s == %s) (%s == %s)".format(x, y, x.hashCode, y.hashCode, x.##, y.##))
+ assert(x == y, "%s/%s != %s/%s".format(x, x.getClass, y, y.getClass))
+ assert(x.## == y.##, "%s != %s".format(x.getClass, y.getClass))
+ }
+ }
+} \ No newline at end of file