diff options
-rw-r--r-- | src/library/scala/collection/MapLike.scala | 28 | ||||
-rw-r--r-- | src/library/scala/collection/mutable/LinkedHashMap.scala | 21 | ||||
-rw-r--r-- | test/files/run/t4954.scala | 45 |
3 files changed, 81 insertions, 13 deletions
diff --git a/src/library/scala/collection/MapLike.scala b/src/library/scala/collection/MapLike.scala index b9b8f62574..75f9ff93db 100644 --- a/src/library/scala/collection/MapLike.scala +++ b/src/library/scala/collection/MapLike.scala @@ -227,30 +227,34 @@ self => def default(key: A): B = throw new NoSuchElementException("key not found: " + key) - /** Filters this map by retaining only keys satisfying a predicate. - * @param p the predicate used to test keys - * @return an immutable map consisting only of those key value pairs of this map where the key satisfies - * the predicate `p`. The resulting map wraps the original map without copying any elements. - */ - def filterKeys(p: A => Boolean): Map[A, B] = new AbstractMap[A, B] with DefaultMap[A, B] { + protected class FilteredKeys(p: A => Boolean) extends AbstractMap[A, B] with DefaultMap[A, B] { override def foreach[C](f: ((A, B)) => C): Unit = for (kv <- self) if (p(kv._1)) f(kv) def iterator = self.iterator.filter(kv => p(kv._1)) override def contains(key: A) = self.contains(key) && p(key) def get(key: A) = if (!p(key)) None else self.get(key) } - - /** Transforms this map by applying a function to every retrieved value. - * @param f the function used to transform values of this map. - * @return a map view which maps every key of this map - * to `f(this(key))`. The resulting map wraps the original map without copying any elements. + + /** Filters this map by retaining only keys satisfying a predicate. + * @param p the predicate used to test keys + * @return an immutable map consisting only of those key value pairs of this map where the key satisfies + * the predicate `p`. The resulting map wraps the original map without copying any elements. */ - def mapValues[C](f: B => C): Map[A, C] = new AbstractMap[A, C] with DefaultMap[A, C] { + def filterKeys(p: A => Boolean): Map[A, B] = new FilteredKeys(p) + + protected class MappedValues[C](f: B => C) extends AbstractMap[A, C] with DefaultMap[A, C] { override def foreach[D](g: ((A, C)) => D): Unit = for ((k, v) <- self) g((k, f(v))) def iterator = for ((k, v) <- self.iterator) yield (k, f(v)) override def size = self.size override def contains(key: A) = self.contains(key) def get(key: A) = self.get(key).map(f) } + + /** Transforms this map by applying a function to every retrieved value. + * @param f the function used to transform values of this map. + * @return a map view which maps every key of this map + * to `f(this(key))`. The resulting map wraps the original map without copying any elements. + */ + def mapValues[C](f: B => C): Map[A, C] = new MappedValues(f) // The following 5 operations (updated, two times +, two times ++) should really be // generic, returning This[B]. We need better covariance support to express that though. diff --git a/src/library/scala/collection/mutable/LinkedHashMap.scala b/src/library/scala/collection/mutable/LinkedHashMap.scala index 4150cf9eba..5643e070f8 100644 --- a/src/library/scala/collection/mutable/LinkedHashMap.scala +++ b/src/library/scala/collection/mutable/LinkedHashMap.scala @@ -49,7 +49,8 @@ class LinkedHashMap[A, B] extends AbstractMap[A, B] with Map[A, B] with MapLike[A, B, LinkedHashMap[A, B]] with HashTable[A, LinkedEntry[A, B]] - with Serializable { + with Serializable +{ override def empty = LinkedHashMap.empty[A, B] override def size = tableSize @@ -107,7 +108,25 @@ class LinkedHashMap[A, B] extends AbstractMap[A, B] if (hasNext) { val res = (cur.key, cur.value); cur = cur.later; res } else Iterator.empty.next } + + protected class FilteredKeys(p: A => Boolean) extends super.FilteredKeys(p) { + override def empty = LinkedHashMap.empty + } + + override def filterKeys(p: A => Boolean): scala.collection.Map[A, B] = new FilteredKeys(p) + protected class MappedValues[C](f: B => C) extends super.MappedValues[C](f) { + override def empty = LinkedHashMap.empty + } + + override def mapValues[C](f: B => C): scala.collection.Map[A, C] = new MappedValues(f) + + protected class DefaultKeySet extends super.DefaultKeySet { + override def empty = LinkedHashSet.empty + } + + override def keySet: scala.collection.Set[A] = new DefaultKeySet + override def keysIterator: Iterator[A] = new AbstractIterator[A] { private var cur = firstEntry def hasNext = cur ne null diff --git a/test/files/run/t4954.scala b/test/files/run/t4954.scala new file mode 100644 index 0000000000..b4916e651d --- /dev/null +++ b/test/files/run/t4954.scala @@ -0,0 +1,45 @@ + + +import collection._ + + +object Test { + + def main(args: Array[String]) { + val m = scala.collection.mutable.LinkedHashMap("one" -> 1, "two" -> 2, "three" -> 3, "four" -> 4, "five" -> 5) + val expected = List("one", "two", "three", "four", "five") + assert(m.keys.iterator.toList == expected) + assert(m.keys.drop(0).iterator.toList == expected) + assert(m.keys.drop(1).iterator.toList == expected.drop(1)) + assert(m.keys.drop(2).iterator.toList == expected.drop(2)) + assert(m.keys.drop(3).iterator.toList == expected.drop(3)) + assert(m.keys.drop(4).iterator.toList == expected.drop(4)) + assert(m.keys.drop(5).iterator.toList == expected.drop(5)) + + val expvals = List(1, 2, 3, 4, 5) + assert(m.values.iterator.toList == expvals) + assert(m.values.drop(0).iterator.toList == expvals) + assert(m.values.drop(1).iterator.toList == expvals.drop(1)) + assert(m.values.drop(2).iterator.toList == expvals.drop(2)) + assert(m.values.drop(3).iterator.toList == expvals.drop(3)) + assert(m.values.drop(4).iterator.toList == expvals.drop(4)) + assert(m.values.drop(5).iterator.toList == expvals.drop(5)) + + val pred = (x: String) => x.length < 6 + val filtered = m.filterKeys(pred) + assert(filtered.drop(0).keys.toList == expected.filter(pred)) + assert(filtered.drop(1).keys.toList == expected.filter(pred).drop(1)) + assert(filtered.drop(2).keys.toList == expected.filter(pred).drop(2)) + assert(filtered.drop(3).keys.toList == expected.filter(pred).drop(3)) + assert(filtered.drop(4).keys.toList == expected.filter(pred).drop(4)) + + val mapped = m.mapValues(-_) + assert(mapped.drop(0).keys.toList == expected) + assert(mapped.drop(1).keys.toList == expected.drop(1)) + assert(mapped.drop(2).keys.toList == expected.drop(2)) + assert(mapped.drop(3).keys.toList == expected.drop(3)) + assert(mapped.drop(4).keys.toList == expected.drop(4)) + assert(mapped.drop(5).keys.toList == expected.drop(5)) + } + +} |