diff options
author | Adriaan Moors <adriaan.moors@typesafe.com> | 2014-01-09 18:20:23 -0800 |
---|---|---|
committer | Adriaan Moors <adriaan.moors@typesafe.com> | 2014-01-09 18:20:23 -0800 |
commit | 513b9e0a715b4cd515da63cf1a20b195d7a3fee0 (patch) | |
tree | 6aa6bc57571ba645b784e99c1f4cd5196a061383 | |
parent | 8e62486cc9b519e783fc194a63ae61a7b20e2fce (diff) | |
parent | 4b6a0a999e935a94501da272a12956c024141cb2 (diff) | |
download | scala-513b9e0a715b4cd515da63cf1a20b195d7a3fee0.tar.gz scala-513b9e0a715b4cd515da63cf1a20b195d7a3fee0.tar.bz2 scala-513b9e0a715b4cd515da63cf1a20b195d7a3fee0.zip |
Merge pull request #3124 from DarkDimius/fix-7443
Fix SI-7443 Range.sum ignoring Numeric argument and always assuming default 'plus' implementation
-rw-r--r-- | src/library/scala/collection/immutable/NumericRange.scala | 33 | ||||
-rw-r--r-- | src/library/scala/collection/immutable/Range.scala | 21 | ||||
-rw-r--r-- | test/files/scalacheck/range.scala | 41 |
3 files changed, 89 insertions, 6 deletions
diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala index 486c2b6c8f..249d76584d 100644 --- a/src/library/scala/collection/immutable/NumericRange.scala +++ b/src/library/scala/collection/immutable/NumericRange.scala @@ -175,9 +175,36 @@ extends AbstractSeq[T] with IndexedSeq[T] with Serializable { catch { case _: ClassCastException => false } final override def sum[B >: T](implicit num: Numeric[B]): B = { - if (isEmpty) this.num fromInt 0 - else if (numRangeElements == 1) head - else ((this.num fromInt numRangeElements) * (head + last) / (this.num fromInt 2)) + // arithmetic series formula can be used for regular addition + if ((num eq scala.math.Numeric.IntIsIntegral)|| + (num eq scala.math.Numeric.BigIntIsIntegral)|| + (num eq scala.math.Numeric.ShortIsIntegral)|| + (num eq scala.math.Numeric.ByteIsIntegral)|| + (num eq scala.math.Numeric.CharIsIntegral)|| + (num eq scala.math.Numeric.LongIsIntegral)|| + (num eq scala.math.Numeric.FloatAsIfIntegral)|| + (num eq scala.math.Numeric.BigDecimalIsFractional)|| + (num eq scala.math.Numeric.DoubleAsIfIntegral)) { + val numAsIntegral = num.asInstanceOf[Integral[B]] + import numAsIntegral._ + if (isEmpty) num fromInt 0 + else if (numRangeElements == 1) head + else ((num fromInt numRangeElements) * (head + last) / (num fromInt 2)) + } else { + // user provided custom Numeric, we cannot rely on arithmetic series formula + if (isEmpty) num.zero + else { + var acc = num.zero + var i = head + var idx = 0 + while(idx < length) { + acc = num.plus(acc, i) + i = i + step + idx = idx + 1 + } + acc + } + } } override lazy val hashCode = super.hashCode() diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala index 00f398a4b0..786b18cd21 100644 --- a/src/library/scala/collection/immutable/Range.scala +++ b/src/library/scala/collection/immutable/Range.scala @@ -259,9 +259,24 @@ extends scala.collection.AbstractSeq[Int] final def contains(x: Int) = isWithinBoundaries(x) && ((x - start) % step == 0) final override def sum[B >: Int](implicit num: Numeric[B]): Int = { - if (isEmpty) 0 - else if (numRangeElements == 1) head - else (numRangeElements.toLong * (head + last) / 2).toInt + if (num eq scala.math.Numeric.IntIsIntegral) { + // this is normal integer range with usual addition. arithmetic series formula can be used + if (isEmpty) 0 + else if (numRangeElements == 1) head + else (numRangeElements.toLong * (head + last) / 2).toInt + } else { + // user provided custom Numeric, we cannot rely on arithmetic series formula + if (isEmpty) num.toInt(num.zero) + else { + var acc = num.zero + var i = head + while(i != terminalElement) { + acc = num.plus(acc, i) + i = i + step + } + num.toInt(acc) + } + } } override def toIterable = this diff --git a/test/files/scalacheck/range.scala b/test/files/scalacheck/range.scala index 6c7c32bfdf..1eb186f303 100644 --- a/test/files/scalacheck/range.scala +++ b/test/files/scalacheck/range.scala @@ -127,6 +127,47 @@ abstract class RangeTest(kind: String) extends Properties("Range "+kind) { (visited == expectedSize(r)) :| str(r) } + property("sum") = forAll(myGen) { r => +// println("----------") +// println("sum "+str(r)) + val rSum = r.sum + val expected = r.length match { + case 0 => 0 + case 1 => r.head + case _ => ((r.head + r.last).toLong * r.length / 2).toInt + } +// println("size: " + r.length) +// println("expected: " + expected) +// println("obtained: " + rSum) + + (rSum == expected) :| str(r) + } + +/* checks that sum respects custom Numeric */ + property("sumCustomNumeric") = forAll(myGen) { r => + val mod = 65536 + object mynum extends Numeric[Int] { + def plus(x: Int, y: Int): Int = (x + y) % mod + override def zero = 0 + + def fromInt(x: Int): Int = ??? + def minus(x: Int, y: Int): Int = ??? + def negate(x: Int): Int = ??? + def times(x: Int, y: Int): Int = ??? + def toDouble(x: Int): Double = ??? + def toFloat(x: Int): Float = ??? + def toInt(x: Int): Int = ((x % mod) + mod * 2) % mod + def toLong(x: Int): Long = ??? + def compare(x: Int, y: Int): Int = ??? + } + + val rSum = r.sum(mynum) + val expected = mynum.toInt(r.sum) + + (rSum == expected) :| str(r) + } + + property("length") = forAll(myGen suchThat (r => expectedSize(r).toInt == expectedSize(r))) { r => // println("length "+str(r)) (r.length == expectedSize(r)) :| str(r) |