summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLukas Rytz <lukas.rytz@typesafe.com>2015-09-21 15:53:52 +0200
committerLukas Rytz <lukas.rytz@typesafe.com>2015-09-21 15:53:52 +0200
commit7eadd684678611dc0f6710f9fdd14ff52bf8fb78 (patch)
tree4e625c025ecf8ce5a9a4043dc13c77e54ba884f1
parent7f1defca22b719b16814626b0218d12e24d5f149 (diff)
parent00d3f103b3db5530bfbf6b565843d0938a3cef48 (diff)
downloadscala-7eadd684678611dc0f6710f9fdd14ff52bf8fb78.tar.gz
scala-7eadd684678611dc0f6710f9fdd14ff52bf8fb78.tar.bz2
scala-7eadd684678611dc0f6710f9fdd14ff52bf8fb78.zip
Merge pull request #4716 from Ichoran/issue/9388
SI-9388 Fix Range behavior around Int.MaxValue
-rw-r--r--src/library/scala/collection/immutable/NumericRange.scala86
-rw-r--r--src/library/scala/collection/immutable/Range.scala29
-rw-r--r--test/files/run/t4658.check7
-rw-r--r--test/files/run/t4658.scala11
-rw-r--r--test/files/scalacheck/range.scala17
-rw-r--r--test/junit/scala/collection/immutable/RangeConsistencyTest.scala24
6 files changed, 128 insertions, 46 deletions
diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala
index 28e56a6d87..11603a118b 100644
--- a/src/library/scala/collection/immutable/NumericRange.scala
+++ b/src/library/scala/collection/immutable/NumericRange.scala
@@ -175,34 +175,68 @@ extends AbstractSeq[T] with IndexedSeq[T] with Serializable {
catch { case _: ClassCastException => false }
final override def sum[B >: T](implicit num: Numeric[B]): B = {
- // arithmetic series formula can be used for regular addition
- if ((num eq scala.math.Numeric.IntIsIntegral)||
- (num eq scala.math.Numeric.BigIntIsIntegral)||
- (num eq scala.math.Numeric.ShortIsIntegral)||
- (num eq scala.math.Numeric.ByteIsIntegral)||
- (num eq scala.math.Numeric.CharIsIntegral)||
- (num eq scala.math.Numeric.LongIsIntegral)||
- (num eq scala.math.Numeric.FloatAsIfIntegral)||
- (num eq scala.math.Numeric.BigDecimalIsFractional)||
- (num eq scala.math.Numeric.DoubleAsIfIntegral)) {
- val numAsIntegral = num.asInstanceOf[Integral[B]]
- import numAsIntegral._
- if (isEmpty) num fromInt 0
- else if (numRangeElements == 1) head
- else ((num fromInt numRangeElements) * (head + last) / (num fromInt 2))
- } else {
- // user provided custom Numeric, we cannot rely on arithmetic series formula
- if (isEmpty) num.zero
+ if (isEmpty) num.zero
+ else if (numRangeElements == 1) head
+ else {
+ // If there is no overflow, use arithmetic series formula
+ // a + ... (n terms total) ... + b = n*(a+b)/2
+ if ((num eq scala.math.Numeric.IntIsIntegral)||
+ (num eq scala.math.Numeric.ShortIsIntegral)||
+ (num eq scala.math.Numeric.ByteIsIntegral)||
+ (num eq scala.math.Numeric.CharIsIntegral)) {
+ // We can do math with no overflow in a Long--easy
+ val exact = (numRangeElements * ((num toLong head) + (num toInt last))) / 2
+ num fromInt exact.toInt
+ }
+ else if (num eq scala.math.Numeric.LongIsIntegral) {
+ // Uh-oh, might be overflow, so we have to divide before we overflow.
+ // Either numRangeElements or (head + last) must be even, so divide the even one before multiplying
+ val a = head.toLong
+ val b = last.toLong
+ val ans =
+ if ((numRangeElements & 1) == 0) (numRangeElements / 2) * (a + b)
+ else numRangeElements * {
+ // Sum is even, but we might overflow it, so divide in pieces and add back remainder
+ val ha = a/2
+ val hb = b/2
+ ha + hb + ((a - 2*ha) + (b - 2*hb)) / 2
+ }
+ ans.asInstanceOf[B]
+ }
+ else if ((num eq scala.math.Numeric.FloatAsIfIntegral) ||
+ (num eq scala.math.Numeric.DoubleAsIfIntegral)) {
+ // Try to compute sum with reasonable accuracy, avoiding over/underflow
+ val numAsIntegral = num.asInstanceOf[Integral[B]]
+ import numAsIntegral._
+ val a = math.abs(head.toDouble)
+ val b = math.abs(last.toDouble)
+ val two = num fromInt 2
+ val nre = num fromInt numRangeElements
+ if (a > 1e38 || b > 1e38) nre * ((head / two) + (last / two)) // Compute in parts to avoid Infinity if possible
+ else (nre / two) * (head + last) // Don't need to worry about infinity; this will be more accurate and avoid underflow
+ }
+ else if ((num eq scala.math.Numeric.BigIntIsIntegral) ||
+ (num eq scala.math.Numeric.BigDecimalIsFractional)) {
+ // No overflow, so we can use arithmetic series formula directly
+ // (not going to worry about running out of memory)
+ val numAsIntegral = num.asInstanceOf[Integral[B]]
+ import numAsIntegral._
+ ((num fromInt numRangeElements) * (head + last)) / (num fromInt 2)
+ }
else {
- var acc = num.zero
- var i = head
- var idx = 0
- while(idx < length) {
- acc = num.plus(acc, i)
- i = i + step
- idx = idx + 1
+ // User provided custom Numeric, so we cannot rely on arithmetic series formula (e.g. won't work on something like Z_6)
+ if (isEmpty) num.zero
+ else {
+ var acc = num.zero
+ var i = head
+ var idx = 0
+ while(idx < length) {
+ acc = num.plus(acc, i)
+ i = i + step
+ idx = idx + 1
+ }
+ acc
}
- acc
}
}
}
diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala
index 79cd673932..ca6720da19 100644
--- a/src/library/scala/collection/immutable/Range.scala
+++ b/src/library/scala/collection/immutable/Range.scala
@@ -153,19 +153,15 @@ extends scala.collection.AbstractSeq[Int]
}
@inline final override def foreach[@specialized(Unit) U](f: Int => U) {
- validateMaxLength()
- val isCommonCase = (start != Int.MinValue || end != Int.MinValue)
- var i = start
- var count = 0
- val terminal = terminalElement
- val step = this.step
- while(
- if(isCommonCase) { i != terminal }
- else { count < numRangeElements }
- ) {
- f(i)
- count += 1
- i += step
+ // Implementation chosen on the basis of favorable microbenchmarks
+ // Note--initialization catches step == 0 so we don't need to here
+ if (!isEmpty) {
+ var i = start
+ while (true) {
+ f(i)
+ if (i == lastElement) return
+ i += step
+ }
}
}
@@ -347,18 +343,19 @@ extends scala.collection.AbstractSeq[Int]
// this is normal integer range with usual addition. arithmetic series formula can be used
if (isEmpty) 0
else if (numRangeElements == 1) head
- else (numRangeElements.toLong * (head + last) / 2).toInt
+ else ((numRangeElements * (head.toLong + last)) / 2).toInt
} else {
// user provided custom Numeric, we cannot rely on arithmetic series formula
if (isEmpty) num.toInt(num.zero)
else {
var acc = num.zero
var i = head
- while(i != terminalElement) {
+ while (true) {
acc = num.plus(acc, i)
+ if (i == lastElement) return num.toInt(acc)
i = i + step
}
- num.toInt(acc)
+ 0 // Never hit this--just to satisfy compiler since it doesn't know while(true) has type Nothing
}
}
}
diff --git a/test/files/run/t4658.check b/test/files/run/t4658.check
index bb6405175e..3bc52daef3 100644
--- a/test/files/run/t4658.check
+++ b/test/files/run/t4658.check
@@ -1,5 +1,5 @@
Ranges:
-1073741824
+-1073741824
1073741824
0
0
@@ -20,7 +20,7 @@ Ranges:
-10
IntRanges:
-1073741824
--1073741824
+1073741824
0
0
55
@@ -78,3 +78,6 @@ BigIntRanges:
-24
-30
-10
+BigInt agrees with Long: true
+Long agrees with Int when rounded: true
+Numeric Int agrees with Range: true
diff --git a/test/files/run/t4658.scala b/test/files/run/t4658.scala
index 8c07c50694..7fc6d4584c 100644
--- a/test/files/run/t4658.scala
+++ b/test/files/run/t4658.scala
@@ -2,6 +2,7 @@ import scala.collection.immutable.NumericRange
//#4658
object Test {
+ // Only works for Int values! Need to rethink explicit otherwise.
case class R(start: Int, end: Int, step: Int = 1, inclusive: Boolean = true)
val rangeData = Array(
@@ -28,6 +29,14 @@ object Test {
numericLongRanges.foreach{range => println(range.sum)}
println("BigIntRanges:")
numericBigIntRanges.foreach{range => println(range.sum)}
+ println("BigInt agrees with Long: " +
+ (numericLongRanges zip numericBigIntRanges).forall{ case (lr, bir) => lr.sum == bir.sum }
+ )
+ println("Long agrees with Int when rounded: " +
+ (numericLongRanges zip numericIntRanges).forall{ case (lr, ir) => lr.sum.toInt == ir.sum }
+ )
+ println("Numeric Int agrees with Range: " +
+ (numericIntRanges zip ranges).forall{ case (ir, r) => ir.sum == r.sum }
+ )
}
-
} \ No newline at end of file
diff --git a/test/files/scalacheck/range.scala b/test/files/scalacheck/range.scala
index 493083a51f..ac24b52f8d 100644
--- a/test/files/scalacheck/range.scala
+++ b/test/files/scalacheck/range.scala
@@ -134,7 +134,22 @@ abstract class RangeTest(kind: String) extends Properties("Range "+kind) {
val expected = r.length match {
case 0 => 0
case 1 => r.head
- case _ => ((r.head + r.last).toLong * r.length / 2).toInt
+ case x if x < 1000 =>
+ // Explicit sum, to guard against having the same mistake in both the
+ // range implementation and test implementation of sum formula.
+ // (Yes, this happened before.)
+ var i = r.head
+ var s = 0L
+ var n = x
+ while (n > 0) {
+ s += i
+ i += r.step
+ n -= 1
+ }
+ s.toInt
+ case _ =>
+ // Make sure head + last doesn't overflow!
+ ((r.head.toLong + r.last) * r.length / 2).toInt
}
// println("size: " + r.length)
// println("expected: " + expected)
diff --git a/test/junit/scala/collection/immutable/RangeConsistencyTest.scala b/test/junit/scala/collection/immutable/RangeConsistencyTest.scala
index 135796979d..760498c162 100644
--- a/test/junit/scala/collection/immutable/RangeConsistencyTest.scala
+++ b/test/junit/scala/collection/immutable/RangeConsistencyTest.scala
@@ -148,4 +148,28 @@ class RangeConsistencyTest {
val bdRange = bd(-10.0) until bd(0.0) by bd(4.5)
assert( bdRange sameElements List(bd(-10.0), bd(-5.5), bd(-1.0)) )
}
+
+ @Test
+ def test_SI9388() {
+ val possiblyNotDefaultNumeric = new scala.math.Numeric[Int] {
+ def fromInt(x: Int) = x
+ def minus(x: Int, y: Int): Int = x - y
+ def negate(x: Int): Int = -x
+ def plus(x: Int, y: Int): Int = x + y
+ def times(x: Int, y: Int): Int = x*y
+ def toDouble(x: Int): Double = x.toDouble
+ def toFloat(x: Int): Float = x.toFloat
+ def toInt(x: Int): Int = x
+ def toLong(x: Int): Long = x.toLong
+ def compare(x: Int, y: Int) = x compare y
+ }
+ val r = (Int.MinValue to Int.MaxValue by (1<<23))
+ val nr = NumericRange(Int.MinValue, Int.MaxValue, 1 << 23)
+ assert({ var i = 0; r.foreach(_ => i += 1); i } == 512)
+ assert({ var i = 0; nr.foreach(_ => i += 1); i } == 512)
+ assert(r.sum == Int.MinValue)
+ assert(nr.sum == Int.MinValue)
+ assert(r.sum(possiblyNotDefaultNumeric) == Int.MinValue)
+ assert(nr.sum(possiblyNotDefaultNumeric) == Int.MinValue)
+ }
}