From 0291797fec629c750cefea17ff23c9adc236fec4 Mon Sep 17 00:00:00 2001 From: Aleksandar Prokopec Date: Wed, 6 Jun 2012 18:56:28 +0200 Subject: Fixes SI-5857. Override `min` and `max` in `Range` and `NumericRange` to check if a default `Ordering` for the numeric type in question is used. If so, bypass traversal and compute the minimum or maximum element. --- .../scala/collection/immutable/NumericRange.scala | 30 ++++++++++++++- src/library/scala/collection/immutable/Range.scala | 14 ++++++- test/files/run/t5857.scala | 45 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 test/files/run/t5857.scala diff --git a/src/library/scala/collection/immutable/NumericRange.scala b/src/library/scala/collection/immutable/NumericRange.scala index 4c82d99c03..5662a11f93 100644 --- a/src/library/scala/collection/immutable/NumericRange.scala +++ b/src/library/scala/collection/immutable/NumericRange.scala @@ -124,7 +124,21 @@ extends AbstractSeq[T] with IndexedSeq[T] with Serializable { if (idx < 0 || idx >= length) throw new IndexOutOfBoundsException(idx.toString) else locationAfterN(idx) } - + + import NumericRange.defaultOrdering + + override def min[T1 >: T](implicit ord: Ordering[T1]): T = + if (ord eq defaultOrdering(num)) { + if (num.signum(step) > 0) start + else last + } else super.min(ord) + + override def max[T1 >: T](implicit ord: Ordering[T1]): T = + if (ord eq defaultOrdering(num)) { + if (num.signum(step) > 0) last + else start + } else super.max(ord) + // Motivated by the desire for Double ranges with BigDecimal precision, // we need some way to map a Range and get another Range. This can't be // done in any fully general way because Ranges are not arbitrary @@ -199,6 +213,7 @@ extends AbstractSeq[T] with IndexedSeq[T] with Serializable { /** A companion object for numeric ranges. */ object NumericRange { + /** Calculates the number of elements in a range given start, end, step, and * whether or not it is inclusive. Throws an exception if step == 0 or * the number of elements exceeds the maximum Int. @@ -257,5 +272,18 @@ object NumericRange { new Exclusive(start, end, step) def inclusive[T](start: T, end: T, step: T)(implicit num: Integral[T]): Inclusive[T] = new Inclusive(start, end, step) + + private[collection] val defaultOrdering = Map[Numeric[_], Ordering[_]]( + Numeric.BigIntIsIntegral -> Ordering.BigInt, + Numeric.IntIsIntegral -> Ordering.Int, + Numeric.ShortIsIntegral -> Ordering.Short, + Numeric.ByteIsIntegral -> Ordering.Byte, + Numeric.CharIsIntegral -> Ordering.Char, + Numeric.LongIsIntegral -> Ordering.Long, + Numeric.FloatAsIfIntegral -> Ordering.Float, + Numeric.DoubleAsIfIntegral -> Ordering.Double, + Numeric.BigDecimalAsIfIntegral -> Ordering.BigDecimal + ) + } diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala index 033331b58b..7607837491 100644 --- a/src/library/scala/collection/immutable/Range.scala +++ b/src/library/scala/collection/immutable/Range.scala @@ -78,7 +78,19 @@ extends collection.AbstractSeq[Int] final val terminalElement = start + numRangeElements * step override def last = if (isEmpty) Nil.last else lastElement - + + override def min[A1 >: Int](implicit ord: Ordering[A1]): Int = + if (ord eq Ordering.Int) { + if (step > 0) start + else last + } else super.min(ord) + + override def max[A1 >: Int](implicit ord: Ordering[A1]): Int = + if (ord eq Ordering.Int) { + if (step > 0) last + else start + } else super.max(ord) + protected def copy(start: Int, end: Int, step: Int): Range = new Range(start, end, step) /** Create a new range with the `start` and `end` values of this range and diff --git a/test/files/run/t5857.scala b/test/files/run/t5857.scala new file mode 100644 index 0000000000..bf67bedf54 --- /dev/null +++ b/test/files/run/t5857.scala @@ -0,0 +1,45 @@ + + + +object Test { + + def time[U](b: =>U): Long = { + val start = System.currentTimeMillis + b + val end = System.currentTimeMillis + + end - start + } + + def main(args: Array[String]) { + val sz = 1000000000 + + val range = 1 to sz + check { assert(range.min == 1, range.min) } + check { assert(range.max == sz, range.max) } + + val descending = sz to 1 by -1 + check { assert(descending.min == 1) } + check { assert(descending.max == sz) } + + val numeric = 1.0 to sz.toDouble by 1 + check { assert(numeric.min == 1.0) } + check { assert(numeric.max == sz.toDouble) } + + val numdesc = sz.toDouble to 1.0 by -1 + check { assert(numdesc.min == 1.0) } + check { assert(numdesc.max == sz.toDouble) } + } + + def check[U](b: =>U) { + val exectime = time { + b + } + + // whatever it is, it should be less than, say, 250ms + // if `max` involves traversal, it takes over 5 seconds on a 3.2GHz i7 CPU + //println(exectime) + assert(exectime < 250, exectime) + } + +} -- cgit v1.2.3