summaryrefslogtreecommitdiff
path: root/src/library
diff options
context:
space:
mode:
authorRex Kerr <ichoran@gmail.com>2013-12-30 18:32:18 -0800
committerRex Kerr <ichoran@gmail.com>2014-01-15 16:16:50 -0800
commit994de8ffd1341a6cb1d7e9dcd617d170e116fde5 (patch)
treeb04f8af2e9ae029f106b00e0785994413e886b13 /src/library
parent681308a3aa737be1dae0f702fddadce88c70f90e (diff)
downloadscala-994de8ffd1341a6cb1d7e9dcd617d170e116fde5.tar.gz
scala-994de8ffd1341a6cb1d7e9dcd617d170e116fde5.tar.bz2
scala-994de8ffd1341a6cb1d7e9dcd617d170e116fde5.zip
SI-4370 Range bug: Wrong result for Long.MinValue to Long.MaxValue by Int.MaxValue
Fixed by rewriting the entire logic for the count method. This is necessary because the old code was making all kinds of assumptions about what numbers were, but the interface is completely generic. Those assumptions still made have been explicitly specified. Note that you have to make some or you end up doing a binary search, which is not exactly fast. The existing routine is 10-20% slower than the old (broken) one in the worst cases. This seems close enough to me to not bother special-casing Long and BigInt, though I note that this could be done for improved performance. Note that ranges that end up in Int ranges defer to Range for count. We can't assume that one is the smallest increment, so both endpoints and the step need to be Int. A new JUnit test has been added to verify that the test works. It secretly contains an alternate BigInt implementation, but that is a lot slower (>5x) than Long.
Diffstat (limited to 'src/library')
-rw-r--r--src/library/scala/collection/immutable/NumericRange.scala91
1 files changed, 71 insertions, 20 deletions
diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala
index 249d76584d..f1ac161e9a 100644
--- a/src/library/scala/collection/immutable/NumericRange.scala
+++ b/src/library/scala/collection/immutable/NumericRange.scala
@@ -241,28 +241,79 @@ object NumericRange {
else if (start == end) if (isInclusive) 1 else 0
else if (upward != posStep) 0
else {
- val diff = num.minus(end, start)
- val jumps = num.toLong(num.quot(diff, step))
- val remainder = num.rem(diff, step)
- val longCount = jumps + (
- if (!isInclusive && zero == remainder) 0 else 1
- )
-
- /* The edge cases keep coming. Since e.g.
- * Long.MaxValue + 1 == Long.MinValue
- * we do some more improbable seeming checks lest
- * overflow turn up as an empty range.
+ /* We have to be frightfully paranoid about running out of range.
+ * We also can't assume that the numbers will fit in a Long.
+ * We will assume that if a > 0, -a can be represented, and if
+ * a < 0, -a+1 can be represented. We also assume that if we
+ * can't fit in Int, we can represent 2*Int.MaxValue+3 (at least).
+ * And we assume that numbers wrap rather than cap when they overflow.
*/
- // The second condition contradicts an empty result.
- val isOverflow = longCount == 0 && num.lt(num.plus(start, step), end) == upward
-
- if (longCount > scala.Int.MaxValue || longCount < 0L || isOverflow) {
- val word = if (isInclusive) "to" else "until"
- val descr = List(start, word, end, "by", step) mkString " "
-
- throw new IllegalArgumentException(descr + ": seqs cannot contain more than Int.MaxValue elements.")
+ // Check whether we can short-circuit by deferring to Int range.
+ val startint = num.toInt(start)
+ if (start == num.fromInt(startint)) {
+ val endint = num.toInt(end)
+ if (end == num.fromInt(endint)) {
+ val stepint = num.toInt(step)
+ if (step == num.fromInt(stepint)) {
+ return {
+ if (isInclusive) Range.inclusive(startint, endint, stepint).length
+ else Range (startint, endint, stepint).length
+ }
+ }
+ }
+ }
+ // If we reach this point, deferring to Int failed.
+ // Numbers may be big.
+ val one = num.one
+ val limit = num.fromInt(Int.MaxValue)
+ def check(t: T): T =
+ if (num.gt(t, limit)) throw new IllegalArgumentException("More than Int.MaxValue elements.")
+ else t
+ // If the range crosses zero, it might overflow when subtracted
+ val startside = num.signum(start)
+ val endside = num.signum(end)
+ num.toInt{
+ if (startside*endside >= 0) {
+ // We're sure we can subtract these numbers.
+ // Note that we do not use .rem because of different conventions for Long and BigInt
+ val diff = num.minus(end, start)
+ val quotient = check(num.quot(diff, step))
+ val remainder = num.minus(diff, num.times(quotient, step))
+ if (!isInclusive && zero == remainder) quotient else check(num.plus(quotient, one))
+ }
+ else {
+ // We might not even be able to subtract these numbers.
+ // Jump in three pieces:
+ // * start to -1 or 1, whichever is closer (waypointA)
+ // * one step, which will take us at least to 0 (ends at waypointB)
+ // * there to the end
+ val negone = num.fromInt(-1)
+ val startlim = if (posStep) negone else one
+ val startdiff = num.minus(startlim, start)
+ val startq = check(num.quot(startdiff, step))
+ val waypointA = if (startq == zero) start else num.plus(start, num.times(startq, step))
+ val waypointB = num.plus(waypointA, step)
+ check {
+ if (num.lt(waypointB, end) != upward) {
+ // No last piece
+ if (isInclusive && waypointB == end) num.plus(startq, num.fromInt(2))
+ else num.plus(startq, one)
+ }
+ else {
+ // There is a last piece
+ val enddiff = num.minus(end,waypointB)
+ val endq = check(num.quot(enddiff, step))
+ val last = if (endq == zero) waypointB else num.plus(waypointB, num.times(endq, step))
+ // Now we have to tally up all the pieces
+ // 1 for the initial value
+ // startq steps to waypointA
+ // 1 step to waypointB
+ // endq steps to the end (one less if !isInclusive and last==end)
+ num.plus(startq, num.plus(endq, if (!isInclusive && last==end) one else num.fromInt(2)))
+ }
+ }
+ }
}
- longCount.toInt
}
}