summaryrefslogtreecommitdiff
path: root/src/library
diff options
context:
space:
mode:
authorRex Kerr <ichoran@gmail.com>2015-08-29 16:02:53 -0700
committerRex Kerr <ichoran@gmail.com>2015-09-19 17:58:05 -0700
commit00d3f103b3db5530bfbf6b565843d0938a3cef48 (patch)
tree4f545dcd6486933766d42a4fcfd53fd70c92c08d /src/library
parentc287df96cf42084828d9528353d5c7ad5c0e4b3a (diff)
downloadscala-00d3f103b3db5530bfbf6b565843d0938a3cef48.tar.gz
scala-00d3f103b3db5530bfbf6b565843d0938a3cef48.tar.bz2
scala-00d3f103b3db5530bfbf6b565843d0938a3cef48.zip
SI-9388 Fix Range behavior around Int.MaxValue
terminalElement (the element _after_ the last one!) was used to terminate foreach loops and sums of non-standard instances of Numeric. Unfortunately, this could result in the end wrapping around and hitting the beginning again, making the first element bad. This patch fixes the behavior by altering the loop to end after the last element is encountered. The particular flavor was chosen out of a few possibilities because it gave the best microbenchmarks on both large and small ranges. Test written. While testing, a bug was also uncovered in NumericRange, and was also fixed. In brief, the logic around sum is rather complex since division is not unique when you have overflow. Floating point has its own complexities, too. Also updated incorrect test t4658 that insisted on incorrect answers (?!) and added logic to make sure it at least stays self-consistent, and fixed the range.scala test which used the same wrong (overflow-prone) formula that the Range collection did.
Diffstat (limited to 'src/library')
-rw-r--r--src/library/scala/collection/immutable/NumericRange.scala86
-rw-r--r--src/library/scala/collection/immutable/Range.scala29
2 files changed, 73 insertions, 42 deletions
diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala
index 28e56a6d87..11603a118b 100644
--- a/src/library/scala/collection/immutable/NumericRange.scala
+++ b/src/library/scala/collection/immutable/NumericRange.scala
@@ -175,34 +175,68 @@ extends AbstractSeq[T] with IndexedSeq[T] with Serializable {
catch { case _: ClassCastException => false }
final override def sum[B >: T](implicit num: Numeric[B]): B = {
- // arithmetic series formula can be used for regular addition
- if ((num eq scala.math.Numeric.IntIsIntegral)||
- (num eq scala.math.Numeric.BigIntIsIntegral)||
- (num eq scala.math.Numeric.ShortIsIntegral)||
- (num eq scala.math.Numeric.ByteIsIntegral)||
- (num eq scala.math.Numeric.CharIsIntegral)||
- (num eq scala.math.Numeric.LongIsIntegral)||
- (num eq scala.math.Numeric.FloatAsIfIntegral)||
- (num eq scala.math.Numeric.BigDecimalIsFractional)||
- (num eq scala.math.Numeric.DoubleAsIfIntegral)) {
- val numAsIntegral = num.asInstanceOf[Integral[B]]
- import numAsIntegral._
- if (isEmpty) num fromInt 0
- else if (numRangeElements == 1) head
- else ((num fromInt numRangeElements) * (head + last) / (num fromInt 2))
- } else {
- // user provided custom Numeric, we cannot rely on arithmetic series formula
- if (isEmpty) num.zero
+ if (isEmpty) num.zero
+ else if (numRangeElements == 1) head
+ else {
+ // If there is no overflow, use arithmetic series formula
+ // a + ... (n terms total) ... + b = n*(a+b)/2
+ if ((num eq scala.math.Numeric.IntIsIntegral)||
+ (num eq scala.math.Numeric.ShortIsIntegral)||
+ (num eq scala.math.Numeric.ByteIsIntegral)||
+ (num eq scala.math.Numeric.CharIsIntegral)) {
+ // We can do math with no overflow in a Long--easy
+ val exact = (numRangeElements * ((num toLong head) + (num toInt last))) / 2
+ num fromInt exact.toInt
+ }
+ else if (num eq scala.math.Numeric.LongIsIntegral) {
+ // Uh-oh, might be overflow, so we have to divide before we overflow.
+ // Either numRangeElements or (head + last) must be even, so divide the even one before multiplying
+ val a = head.toLong
+ val b = last.toLong
+ val ans =
+ if ((numRangeElements & 1) == 0) (numRangeElements / 2) * (a + b)
+ else numRangeElements * {
+ // Sum is even, but we might overflow it, so divide in pieces and add back remainder
+ val ha = a/2
+ val hb = b/2
+ ha + hb + ((a - 2*ha) + (b - 2*hb)) / 2
+ }
+ ans.asInstanceOf[B]
+ }
+ else if ((num eq scala.math.Numeric.FloatAsIfIntegral) ||
+ (num eq scala.math.Numeric.DoubleAsIfIntegral)) {
+ // Try to compute sum with reasonable accuracy, avoiding over/underflow
+ val numAsIntegral = num.asInstanceOf[Integral[B]]
+ import numAsIntegral._
+ val a = math.abs(head.toDouble)
+ val b = math.abs(last.toDouble)
+ val two = num fromInt 2
+ val nre = num fromInt numRangeElements
+ if (a > 1e38 || b > 1e38) nre * ((head / two) + (last / two)) // Compute in parts to avoid Infinity if possible
+ else (nre / two) * (head + last) // Don't need to worry about infinity; this will be more accurate and avoid underflow
+ }
+ else if ((num eq scala.math.Numeric.BigIntIsIntegral) ||
+ (num eq scala.math.Numeric.BigDecimalIsFractional)) {
+ // No overflow, so we can use arithmetic series formula directly
+ // (not going to worry about running out of memory)
+ val numAsIntegral = num.asInstanceOf[Integral[B]]
+ import numAsIntegral._
+ ((num fromInt numRangeElements) * (head + last)) / (num fromInt 2)
+ }
else {
- var acc = num.zero
- var i = head
- var idx = 0
- while(idx < length) {
- acc = num.plus(acc, i)
- i = i + step
- idx = idx + 1
+ // User provided custom Numeric, so we cannot rely on arithmetic series formula (e.g. won't work on something like Z_6)
+ if (isEmpty) num.zero
+ else {
+ var acc = num.zero
+ var i = head
+ var idx = 0
+ while(idx < length) {
+ acc = num.plus(acc, i)
+ i = i + step
+ idx = idx + 1
+ }
+ acc
}
- acc
}
}
}
diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala
index 79cd673932..ca6720da19 100644
--- a/src/library/scala/collection/immutable/Range.scala
+++ b/src/library/scala/collection/immutable/Range.scala
@@ -153,19 +153,15 @@ extends scala.collection.AbstractSeq[Int]
}
@inline final override def foreach[@specialized(Unit) U](f: Int => U) {
- validateMaxLength()
- val isCommonCase = (start != Int.MinValue || end != Int.MinValue)
- var i = start
- var count = 0
- val terminal = terminalElement
- val step = this.step
- while(
- if(isCommonCase) { i != terminal }
- else { count < numRangeElements }
- ) {
- f(i)
- count += 1
- i += step
+ // Implementation chosen on the basis of favorable microbenchmarks
+ // Note--initialization catches step == 0 so we don't need to here
+ if (!isEmpty) {
+ var i = start
+ while (true) {
+ f(i)
+ if (i == lastElement) return
+ i += step
+ }
}
}
@@ -347,18 +343,19 @@ extends scala.collection.AbstractSeq[Int]
// this is normal integer range with usual addition. arithmetic series formula can be used
if (isEmpty) 0
else if (numRangeElements == 1) head
- else (numRangeElements.toLong * (head + last) / 2).toInt
+ else ((numRangeElements * (head.toLong + last)) / 2).toInt
} else {
// user provided custom Numeric, we cannot rely on arithmetic series formula
if (isEmpty) num.toInt(num.zero)
else {
var acc = num.zero
var i = head
- while(i != terminalElement) {
+ while (true) {
acc = num.plus(acc, i)
+ if (i == lastElement) return num.toInt(acc)
i = i + step
}
- num.toInt(acc)
+ 0 // Never hit this--just to satisfy compiler since it doesn't know while(true) has type Nothing
}
}
}