From 00d3f103b3db5530bfbf6b565843d0938a3cef48 Mon Sep 17 00:00:00 2001 From: Rex Kerr Date: Sat, 29 Aug 2015 16:02:53 -0700 Subject: SI-9388 Fix Range behavior around Int.MaxValue terminalElement (the element _after_ the last one!) was used to terminate foreach loops and sums of non-standard instances of Numeric. Unfortunately, this could result in the end wrapping around and hitting the beginning again, making the first element bad. This patch fixes the behavior by altering the loop to end after the last element is encountered. The particular flavor was chosen out of a few possibilities because it gave the best microbenchmarks on both large and small ranges. Test written. While testing, a bug was also uncovered in NumericRange, and was also fixed. In brief, the logic around sum is rather complex since division is not unique when you have overflow. Floating point has its own complexities, too. Also updated incorrect test t4658 that insisted on incorrect answers (?!) and added logic to make sure it at least stays self-consistent, and fixed the range.scala test which used the same wrong (overflow-prone) formula that the Range collection did. --- .../scala/collection/immutable/NumericRange.scala | 86 +++++++++++++++------- src/library/scala/collection/immutable/Range.scala | 29 ++++---- test/files/run/t4658.check | 7 +- test/files/run/t4658.scala | 11 ++- test/files/scalacheck/range.scala | 17 ++++- .../immutable/RangeConsistencyTest.scala | 24 ++++++ 6 files changed, 128 insertions(+), 46 deletions(-) diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala index 28e56a6d87..11603a118b 100644 --- a/src/library/scala/collection/immutable/NumericRange.scala +++ b/src/library/scala/collection/immutable/NumericRange.scala @@ -175,34 +175,68 @@ extends AbstractSeq[T] with IndexedSeq[T] with Serializable { catch { case _: ClassCastException => false } final override def sum[B >: T](implicit num: Numeric[B]): B = { - // arithmetic series formula can be used for regular addition - if ((num eq scala.math.Numeric.IntIsIntegral)|| - (num eq scala.math.Numeric.BigIntIsIntegral)|| - (num eq scala.math.Numeric.ShortIsIntegral)|| - (num eq scala.math.Numeric.ByteIsIntegral)|| - (num eq scala.math.Numeric.CharIsIntegral)|| - (num eq scala.math.Numeric.LongIsIntegral)|| - (num eq scala.math.Numeric.FloatAsIfIntegral)|| - (num eq scala.math.Numeric.BigDecimalIsFractional)|| - (num eq scala.math.Numeric.DoubleAsIfIntegral)) { - val numAsIntegral = num.asInstanceOf[Integral[B]] - import numAsIntegral._ - if (isEmpty) num fromInt 0 - else if (numRangeElements == 1) head - else ((num fromInt numRangeElements) * (head + last) / (num fromInt 2)) - } else { - // user provided custom Numeric, we cannot rely on arithmetic series formula - if (isEmpty) num.zero + if (isEmpty) num.zero + else if (numRangeElements == 1) head + else { + // If there is no overflow, use arithmetic series formula + // a + ... (n terms total) ... + b = n*(a+b)/2 + if ((num eq scala.math.Numeric.IntIsIntegral)|| + (num eq scala.math.Numeric.ShortIsIntegral)|| + (num eq scala.math.Numeric.ByteIsIntegral)|| + (num eq scala.math.Numeric.CharIsIntegral)) { + // We can do math with no overflow in a Long--easy + val exact = (numRangeElements * ((num toLong head) + (num toInt last))) / 2 + num fromInt exact.toInt + } + else if (num eq scala.math.Numeric.LongIsIntegral) { + // Uh-oh, might be overflow, so we have to divide before we overflow. + // Either numRangeElements or (head + last) must be even, so divide the even one before multiplying + val a = head.toLong + val b = last.toLong + val ans = + if ((numRangeElements & 1) == 0) (numRangeElements / 2) * (a + b) + else numRangeElements * { + // Sum is even, but we might overflow it, so divide in pieces and add back remainder + val ha = a/2 + val hb = b/2 + ha + hb + ((a - 2*ha) + (b - 2*hb)) / 2 + } + ans.asInstanceOf[B] + } + else if ((num eq scala.math.Numeric.FloatAsIfIntegral) || + (num eq scala.math.Numeric.DoubleAsIfIntegral)) { + // Try to compute sum with reasonable accuracy, avoiding over/underflow + val numAsIntegral = num.asInstanceOf[Integral[B]] + import numAsIntegral._ + val a = math.abs(head.toDouble) + val b = math.abs(last.toDouble) + val two = num fromInt 2 + val nre = num fromInt numRangeElements + if (a > 1e38 || b > 1e38) nre * ((head / two) + (last / two)) // Compute in parts to avoid Infinity if possible + else (nre / two) * (head + last) // Don't need to worry about infinity; this will be more accurate and avoid underflow + } + else if ((num eq scala.math.Numeric.BigIntIsIntegral) || + (num eq scala.math.Numeric.BigDecimalIsFractional)) { + // No overflow, so we can use arithmetic series formula directly + // (not going to worry about running out of memory) + val numAsIntegral = num.asInstanceOf[Integral[B]] + import numAsIntegral._ + ((num fromInt numRangeElements) * (head + last)) / (num fromInt 2) + } else { - var acc = num.zero - var i = head - var idx = 0 - while(idx < length) { - acc = num.plus(acc, i) - i = i + step - idx = idx + 1 + // User provided custom Numeric, so we cannot rely on arithmetic series formula (e.g. won't work on something like Z_6) + if (isEmpty) num.zero + else { + var acc = num.zero + var i = head + var idx = 0 + while(idx < length) { + acc = num.plus(acc, i) + i = i + step + idx = idx + 1 + } + acc } - acc } } } diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala index 79cd673932..ca6720da19 100644 --- a/src/library/scala/collection/immutable/Range.scala +++ b/src/library/scala/collection/immutable/Range.scala @@ -153,19 +153,15 @@ extends scala.collection.AbstractSeq[Int] } @inline final override def foreach[@specialized(Unit) U](f: Int => U) { - validateMaxLength() - val isCommonCase = (start != Int.MinValue || end != Int.MinValue) - var i = start - var count = 0 - val terminal = terminalElement - val step = this.step - while( - if(isCommonCase) { i != terminal } - else { count < numRangeElements } - ) { - f(i) - count += 1 - i += step + // Implementation chosen on the basis of favorable microbenchmarks + // Note--initialization catches step == 0 so we don't need to here + if (!isEmpty) { + var i = start + while (true) { + f(i) + if (i == lastElement) return + i += step + } } } @@ -347,18 +343,19 @@ extends scala.collection.AbstractSeq[Int] // this is normal integer range with usual addition. arithmetic series formula can be used if (isEmpty) 0 else if (numRangeElements == 1) head - else (numRangeElements.toLong * (head + last) / 2).toInt + else ((numRangeElements * (head.toLong + last)) / 2).toInt } else { // user provided custom Numeric, we cannot rely on arithmetic series formula if (isEmpty) num.toInt(num.zero) else { var acc = num.zero var i = head - while(i != terminalElement) { + while (true) { acc = num.plus(acc, i) + if (i == lastElement) return num.toInt(acc) i = i + step } - num.toInt(acc) + 0 // Never hit this--just to satisfy compiler since it doesn't know while(true) has type Nothing } } } diff --git a/test/files/run/t4658.check b/test/files/run/t4658.check index bb6405175e..3bc52daef3 100644 --- a/test/files/run/t4658.check +++ b/test/files/run/t4658.check @@ -1,5 +1,5 @@ Ranges: -1073741824 +-1073741824 1073741824 0 0 @@ -20,7 +20,7 @@ Ranges: -10 IntRanges: -1073741824 --1073741824 +1073741824 0 0 55 @@ -78,3 +78,6 @@ BigIntRanges: -24 -30 -10 +BigInt agrees with Long: true +Long agrees with Int when rounded: true +Numeric Int agrees with Range: true diff --git a/test/files/run/t4658.scala b/test/files/run/t4658.scala index 8c07c50694..7fc6d4584c 100644 --- a/test/files/run/t4658.scala +++ b/test/files/run/t4658.scala @@ -2,6 +2,7 @@ import scala.collection.immutable.NumericRange //#4658 object Test { + // Only works for Int values! Need to rethink explicit otherwise. case class R(start: Int, end: Int, step: Int = 1, inclusive: Boolean = true) val rangeData = Array( @@ -28,6 +29,14 @@ object Test { numericLongRanges.foreach{range => println(range.sum)} println("BigIntRanges:") numericBigIntRanges.foreach{range => println(range.sum)} + println("BigInt agrees with Long: " + + (numericLongRanges zip numericBigIntRanges).forall{ case (lr, bir) => lr.sum == bir.sum } + ) + println("Long agrees with Int when rounded: " + + (numericLongRanges zip numericIntRanges).forall{ case (lr, ir) => lr.sum.toInt == ir.sum } + ) + println("Numeric Int agrees with Range: " + + (numericIntRanges zip ranges).forall{ case (ir, r) => ir.sum == r.sum } + ) } - } \ No newline at end of file diff --git a/test/files/scalacheck/range.scala b/test/files/scalacheck/range.scala index 493083a51f..ac24b52f8d 100644 --- a/test/files/scalacheck/range.scala +++ b/test/files/scalacheck/range.scala @@ -134,7 +134,22 @@ abstract class RangeTest(kind: String) extends Properties("Range "+kind) { val expected = r.length match { case 0 => 0 case 1 => r.head - case _ => ((r.head + r.last).toLong * r.length / 2).toInt + case x if x < 1000 => + // Explicit sum, to guard against having the same mistake in both the + // range implementation and test implementation of sum formula. + // (Yes, this happened before.) + var i = r.head + var s = 0L + var n = x + while (n > 0) { + s += i + i += r.step + n -= 1 + } + s.toInt + case _ => + // Make sure head + last doesn't overflow! + ((r.head.toLong + r.last) * r.length / 2).toInt } // println("size: " + r.length) // println("expected: " + expected) diff --git a/test/junit/scala/collection/immutable/RangeConsistencyTest.scala b/test/junit/scala/collection/immutable/RangeConsistencyTest.scala index 135796979d..760498c162 100644 --- a/test/junit/scala/collection/immutable/RangeConsistencyTest.scala +++ b/test/junit/scala/collection/immutable/RangeConsistencyTest.scala @@ -148,4 +148,28 @@ class RangeConsistencyTest { val bdRange = bd(-10.0) until bd(0.0) by bd(4.5) assert( bdRange sameElements List(bd(-10.0), bd(-5.5), bd(-1.0)) ) } + + @Test + def test_SI9388() { + val possiblyNotDefaultNumeric = new scala.math.Numeric[Int] { + def fromInt(x: Int) = x + def minus(x: Int, y: Int): Int = x - y + def negate(x: Int): Int = -x + def plus(x: Int, y: Int): Int = x + y + def times(x: Int, y: Int): Int = x*y + def toDouble(x: Int): Double = x.toDouble + def toFloat(x: Int): Float = x.toFloat + def toInt(x: Int): Int = x + def toLong(x: Int): Long = x.toLong + def compare(x: Int, y: Int) = x compare y + } + val r = (Int.MinValue to Int.MaxValue by (1<<23)) + val nr = NumericRange(Int.MinValue, Int.MaxValue, 1 << 23) + assert({ var i = 0; r.foreach(_ => i += 1); i } == 512) + assert({ var i = 0; nr.foreach(_ => i += 1); i } == 512) + assert(r.sum == Int.MinValue) + assert(nr.sum == Int.MinValue) + assert(r.sum(possiblyNotDefaultNumeric) == Int.MinValue) + assert(nr.sum(possiblyNotDefaultNumeric) == Int.MinValue) + } } -- cgit v1.2.3