diff options
author | Rex Kerr <ichoran@gmail.com> | 2013-12-30 18:32:18 -0800 |
---|---|---|
committer | Rex Kerr <ichoran@gmail.com> | 2014-01-15 16:16:50 -0800 |
commit | 994de8ffd1341a6cb1d7e9dcd617d170e116fde5 (patch) | |
tree | b04f8af2e9ae029f106b00e0785994413e886b13 | |
parent | 681308a3aa737be1dae0f702fddadce88c70f90e (diff) | |
download | scala-994de8ffd1341a6cb1d7e9dcd617d170e116fde5.tar.gz scala-994de8ffd1341a6cb1d7e9dcd617d170e116fde5.tar.bz2 scala-994de8ffd1341a6cb1d7e9dcd617d170e116fde5.zip |
SI-4370 Range bug: Wrong result for Long.MinValue to Long.MaxValue by Int.MaxValue
Fixed by rewriting the entire logic for the count method. This is necessary because the old code was making all kinds of assumptions about what numbers were, but the interface is completely generic.
Those assumptions still made have been explicitly specified. Note that you have to make some or you end up doing a binary search, which is not exactly fast.
The existing routine is 10-20% slower than the old (broken) one in the worst cases. This seems close enough to me to not bother special-casing Long and BigInt, though I note that this could be done for improved performance.
Note that ranges that end up in Int ranges defer to Range for count. We can't assume that one is the smallest increment, so both endpoints and the step need to be Int.
A new JUnit test has been added to verify that the test works. It secretly contains an alternate BigInt implementation, but that is a lot slower (>5x) than Long.
-rw-r--r-- | src/library/scala/collection/immutable/NumericRange.scala | 91 | ||||
-rw-r--r-- | test/junit/scala/collection/NumericRangeTest.scala | 123 |
2 files changed, 194 insertions, 20 deletions
diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala index 249d76584d..f1ac161e9a 100644 --- a/src/library/scala/collection/immutable/NumericRange.scala +++ b/src/library/scala/collection/immutable/NumericRange.scala @@ -241,28 +241,79 @@ object NumericRange { else if (start == end) if (isInclusive) 1 else 0 else if (upward != posStep) 0 else { - val diff = num.minus(end, start) - val jumps = num.toLong(num.quot(diff, step)) - val remainder = num.rem(diff, step) - val longCount = jumps + ( - if (!isInclusive && zero == remainder) 0 else 1 - ) - - /* The edge cases keep coming. Since e.g. - * Long.MaxValue + 1 == Long.MinValue - * we do some more improbable seeming checks lest - * overflow turn up as an empty range. + /* We have to be frightfully paranoid about running out of range. + * We also can't assume that the numbers will fit in a Long. + * We will assume that if a > 0, -a can be represented, and if + * a < 0, -a+1 can be represented. We also assume that if we + * can't fit in Int, we can represent 2*Int.MaxValue+3 (at least). + * And we assume that numbers wrap rather than cap when they overflow. */ - // The second condition contradicts an empty result. - val isOverflow = longCount == 0 && num.lt(num.plus(start, step), end) == upward - - if (longCount > scala.Int.MaxValue || longCount < 0L || isOverflow) { - val word = if (isInclusive) "to" else "until" - val descr = List(start, word, end, "by", step) mkString " " - - throw new IllegalArgumentException(descr + ": seqs cannot contain more than Int.MaxValue elements.") + // Check whether we can short-circuit by deferring to Int range. + val startint = num.toInt(start) + if (start == num.fromInt(startint)) { + val endint = num.toInt(end) + if (end == num.fromInt(endint)) { + val stepint = num.toInt(step) + if (step == num.fromInt(stepint)) { + return { + if (isInclusive) Range.inclusive(startint, endint, stepint).length + else Range (startint, endint, stepint).length + } + } + } + } + // If we reach this point, deferring to Int failed. + // Numbers may be big. + val one = num.one + val limit = num.fromInt(Int.MaxValue) + def check(t: T): T = + if (num.gt(t, limit)) throw new IllegalArgumentException("More than Int.MaxValue elements.") + else t + // If the range crosses zero, it might overflow when subtracted + val startside = num.signum(start) + val endside = num.signum(end) + num.toInt{ + if (startside*endside >= 0) { + // We're sure we can subtract these numbers. + // Note that we do not use .rem because of different conventions for Long and BigInt + val diff = num.minus(end, start) + val quotient = check(num.quot(diff, step)) + val remainder = num.minus(diff, num.times(quotient, step)) + if (!isInclusive && zero == remainder) quotient else check(num.plus(quotient, one)) + } + else { + // We might not even be able to subtract these numbers. + // Jump in three pieces: + // * start to -1 or 1, whichever is closer (waypointA) + // * one step, which will take us at least to 0 (ends at waypointB) + // * there to the end + val negone = num.fromInt(-1) + val startlim = if (posStep) negone else one + val startdiff = num.minus(startlim, start) + val startq = check(num.quot(startdiff, step)) + val waypointA = if (startq == zero) start else num.plus(start, num.times(startq, step)) + val waypointB = num.plus(waypointA, step) + check { + if (num.lt(waypointB, end) != upward) { + // No last piece + if (isInclusive && waypointB == end) num.plus(startq, num.fromInt(2)) + else num.plus(startq, one) + } + else { + // There is a last piece + val enddiff = num.minus(end,waypointB) + val endq = check(num.quot(enddiff, step)) + val last = if (endq == zero) waypointB else num.plus(waypointB, num.times(endq, step)) + // Now we have to tally up all the pieces + // 1 for the initial value + // startq steps to waypointA + // 1 step to waypointB + // endq steps to the end (one less if !isInclusive and last==end) + num.plus(startq, num.plus(endq, if (!isInclusive && last==end) one else num.fromInt(2))) + } + } + } } - longCount.toInt } } diff --git a/test/junit/scala/collection/NumericRangeTest.scala b/test/junit/scala/collection/NumericRangeTest.scala new file mode 100644 index 0000000000..0260723b9d --- /dev/null +++ b/test/junit/scala/collection/NumericRangeTest.scala @@ -0,0 +1,123 @@ +package scala.collection.immutable + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test +import scala.math._ +import scala.util._ + +/* Tests various maps by making sure they all agree on the same answers. */ +@RunWith(classOf[JUnit4]) +class RangeConsistencyTest { + def r2nr[T: Integral]( + r: Range, puff: T, stride: T, check: (T,T) => Boolean, bi: T => BigInt + ): List[(BigInt,Try[Int])] = { + val num = implicitly[Integral[T]] + import num._ + val one = num.one + + if (!check(puff, fromInt(r.start))) return Nil + val start = puff * fromInt(r.start) + val sp1 = start + one + val sn1 = start - one + + if (!check(puff, fromInt(r.end))) return Nil + val end = puff * fromInt(r.end) + val ep1 = end + one + val en1 = end - one + + if (!check(stride, fromInt(r.step))) return Nil + val step = stride * fromInt(r.step) + + def NR(s: T, e: T, i: T) = { + val delta = (bi(e) - bi(s)).abs - (if (r.isInclusive) 0 else 1) + val n = if (r.length == 0) BigInt(0) else delta / bi(i).abs + 1 + if (r.isInclusive) { + (n, Try(NumericRange.inclusive(s,e,i).length)) + } + else { + (n, Try(NumericRange(s,e,i).length)) + } + } + + List(NR(start, end, step)) ::: + (if (sn1 < start) List(NR(sn1, end, step)) else Nil) ::: + (if (start < sp1) List(NR(sp1, end, step)) else Nil) ::: + (if (en1 < end) List(NR(start, en1, step)) else Nil) ::: + (if (end < ep1) List(NR(start, ep1, step)) else Nil) + } + + // Motivated by SI-4370: Wrong result for Long.MinValue to Long.MaxValue by Int.MaxValue + @Test + def rangeChurnTest() { + val rn = new Random(4370) + for (i <- 0 to 10000) { control.Breaks.breakable { + val start = rn.nextInt + val end = rn.nextInt + val step = rn.nextInt(4) match { + case 0 => 1 + case 1 => -1 + case 2 => (rn.nextInt(11)+2)*(2*rn.nextInt(2)+1) + case 3 => var x = rn.nextInt; while (x==0) x = rn.nextInt; x + } + val r = if (rn.nextBoolean) Range.inclusive(start, end, step) else Range(start, end, step) + + try { r.length } + catch { case iae: IllegalArgumentException => control.Breaks.break } + + val lpuff = rn.nextInt(4) match { + case 0 => 1L + case 1 => rn.nextInt(11)+2L + case 2 => 1L << rn.nextInt(60) + case 3 => math.max(1L, math.abs(rn.nextLong)) + } + val lstride = rn.nextInt(4) match { + case 0 => lpuff + case 1 => 1L + case 2 => 1L << rn.nextInt(60) + case 3 => math.max(1L, math.abs(rn.nextLong)) + } + val lr = r2nr[Long]( + r, lpuff, lstride, + (a,b) => { val x = BigInt(a)*BigInt(b); x.isValidLong }, + x => BigInt(x) + ) + + lr.foreach{ case (n,t) => assert( + t match { + case Failure(_) => n > Int.MaxValue + case Success(m) => n == m + }, + (r.start, r.end, r.step, r.isInclusive, lpuff, lstride, n, t) + )} + + val bipuff = rn.nextInt(3) match { + case 0 => BigInt(1) + case 1 => BigInt(rn.nextLong) + Long.MaxValue + 2 + case 2 => BigInt("1" + "0"*(rn.nextInt(100)+1)) + } + val bistride = rn.nextInt(3) match { + case 0 => bipuff + case 1 => BigInt(1) + case 2 => BigInt("1" + "0"*(rn.nextInt(100)+1)) + } + val bir = r2nr[BigInt](r, bipuff, bistride, (a,b) => true, identity) + + bir.foreach{ case (n,t) => assert( + t match { + case Failure(_) => n > Int.MaxValue + case Success(m) => n == m + }, + (r.start, r.end, r.step, r.isInclusive, bipuff, bistride, n, t) + )} + }} + } + + @Test + def testSI4370() { assert{ + Try((Long.MinValue to Long.MaxValue by Int.MaxValue).length) match { + case Failure(iae: IllegalArgumentException) => true + case _ => false + } + }} +} |