diff options
Diffstat (limited to 'src/library')
-rw-r--r-- | src/library/scala/collection/immutable/NumericRange.scala | 91 |
1 files changed, 71 insertions, 20 deletions
diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala index 249d76584d..f1ac161e9a 100644 --- a/src/library/scala/collection/immutable/NumericRange.scala +++ b/src/library/scala/collection/immutable/NumericRange.scala @@ -241,28 +241,79 @@ object NumericRange { else if (start == end) if (isInclusive) 1 else 0 else if (upward != posStep) 0 else { - val diff = num.minus(end, start) - val jumps = num.toLong(num.quot(diff, step)) - val remainder = num.rem(diff, step) - val longCount = jumps + ( - if (!isInclusive && zero == remainder) 0 else 1 - ) - - /* The edge cases keep coming. Since e.g. - * Long.MaxValue + 1 == Long.MinValue - * we do some more improbable seeming checks lest - * overflow turn up as an empty range. + /* We have to be frightfully paranoid about running out of range. + * We also can't assume that the numbers will fit in a Long. + * We will assume that if a > 0, -a can be represented, and if + * a < 0, -a+1 can be represented. We also assume that if we + * can't fit in Int, we can represent 2*Int.MaxValue+3 (at least). + * And we assume that numbers wrap rather than cap when they overflow. */ - // The second condition contradicts an empty result. - val isOverflow = longCount == 0 && num.lt(num.plus(start, step), end) == upward - - if (longCount > scala.Int.MaxValue || longCount < 0L || isOverflow) { - val word = if (isInclusive) "to" else "until" - val descr = List(start, word, end, "by", step) mkString " " - - throw new IllegalArgumentException(descr + ": seqs cannot contain more than Int.MaxValue elements.") + // Check whether we can short-circuit by deferring to Int range. + val startint = num.toInt(start) + if (start == num.fromInt(startint)) { + val endint = num.toInt(end) + if (end == num.fromInt(endint)) { + val stepint = num.toInt(step) + if (step == num.fromInt(stepint)) { + return { + if (isInclusive) Range.inclusive(startint, endint, stepint).length + else Range (startint, endint, stepint).length + } + } + } + } + // If we reach this point, deferring to Int failed. + // Numbers may be big. + val one = num.one + val limit = num.fromInt(Int.MaxValue) + def check(t: T): T = + if (num.gt(t, limit)) throw new IllegalArgumentException("More than Int.MaxValue elements.") + else t + // If the range crosses zero, it might overflow when subtracted + val startside = num.signum(start) + val endside = num.signum(end) + num.toInt{ + if (startside*endside >= 0) { + // We're sure we can subtract these numbers. + // Note that we do not use .rem because of different conventions for Long and BigInt + val diff = num.minus(end, start) + val quotient = check(num.quot(diff, step)) + val remainder = num.minus(diff, num.times(quotient, step)) + if (!isInclusive && zero == remainder) quotient else check(num.plus(quotient, one)) + } + else { + // We might not even be able to subtract these numbers. + // Jump in three pieces: + // * start to -1 or 1, whichever is closer (waypointA) + // * one step, which will take us at least to 0 (ends at waypointB) + // * there to the end + val negone = num.fromInt(-1) + val startlim = if (posStep) negone else one + val startdiff = num.minus(startlim, start) + val startq = check(num.quot(startdiff, step)) + val waypointA = if (startq == zero) start else num.plus(start, num.times(startq, step)) + val waypointB = num.plus(waypointA, step) + check { + if (num.lt(waypointB, end) != upward) { + // No last piece + if (isInclusive && waypointB == end) num.plus(startq, num.fromInt(2)) + else num.plus(startq, one) + } + else { + // There is a last piece + val enddiff = num.minus(end,waypointB) + val endq = check(num.quot(enddiff, step)) + val last = if (endq == zero) waypointB else num.plus(waypointB, num.times(endq, step)) + // Now we have to tally up all the pieces + // 1 for the initial value + // startq steps to waypointA + // 1 step to waypointB + // endq steps to the end (one less if !isInclusive and last==end) + num.plus(startq, num.plus(endq, if (!isInclusive && last==end) one else num.fromInt(2))) + } + } + } } - longCount.toInt } } |