summaryrefslogtreecommitdiff
path: root/core/src/main/scala/mill/util/OSet.scala
blob: a4fcc406cd3c41c7044d7a2c63f82155da845947 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package mill.util



import scala.collection.mutable
/** Wrapper whose `OSet`s throw an exception when an already-present element is appended. */
object Strict extends OSetWrapper(strictUniqueness = true)

/** Wrapper whose `OSet`s ignore an already-present element instead of throwing. */
object Loose extends OSetWrapper(strictUniqueness = false)
sealed class OSetWrapper(strictUniqueness: Boolean){
  /**
    * A collection of unique elements with fast `contains` and a deterministic
    * ordering.
    *
    * Whether inserting a duplicate raises an exception or is ignored is
    * decided by the `strictUniqueness` flag of the enclosing wrapper. Where it
    * raises, call `toSeq.distinct` on the input first if you explicitly want
    * duplicates to be swallowed.
    */
  trait OSet[V] extends TraversableOnce[V] {

    // ---- Element access ----

    /** Whether `v` is an element of this set. */
    def contains(v: V): Boolean

    /** The elements of this set as an iterator. */
    def items: Iterator[V]

    /** The elements of this set as an indexed sequence. */
    def indexed: IndexedSeq[V]

    // ---- Element-wise transformations ----

    /** A set holding the result of applying `f` to every element. */
    def map[T](f: V => T): OSet[T]

    /** A set holding all elements produced by applying `f` to every element. */
    def flatMap[T](f: V => TraversableOnce[T]): OSet[T]

    /** A set holding `f(v)` for every element `v` at which `f` is defined. */
    def collect[T](f: PartialFunction[V, T]): OSet[T]

    /** A set holding only the elements that satisfy `f`. */
    def filter(f: V => Boolean): OSet[V]

    /** Filtering hook used by guards in `for` comprehensions. */
    def withFilter(f: V => Boolean): OSet[V]

    // ---- Reordering and combining ----

    /** A set holding the same elements in the opposite order. */
    def reverse: OSet[V]

    /** A set pairing every element with its position. */
    def zipWithIndex: OSet[(V, Int)]

    /** A set pairing the elements of this set with those of `other`. */
    def zip[T](other: OSet[T]): OSet[(V, T)]

    /** A set holding the elements of this set followed by those of `other`. */
    def ++[T >: V](other: TraversableOnce[T]): OSet[T]
  }

  object OSet{
    /** Creates a new `OSet` that contains no elements. */
    def empty[V]: OSet[V] = new Mutable[V]()
    /**
      * upickle `ReadWriter` for `OSet[T]`, available whenever a `ReadWriter`
      * for the element type `T` is in implicit scope.
      *
      * Writing serialises the elements as a list (`oset.toList`). Reading
      * decodes a `Seq[T]` and rebuilds the set through `OSet.from`, so the
      * decoded elements are re-appended one by one.
      */
    implicit def jsonFormat[T: upickle.default.ReadWriter]: upickle.default.ReadWriter[OSet[T]] =
      upickle.default.ReadWriter[OSet[T]] (
        oset => upickle.default.writeJs(oset.toList),
        {case json => OSet.from(upickle.default.readJs[Seq[T]](json))}
      )
    /** Builds an `OSet` from the given elements by handing them to `from`. */
    def apply[V](items: V*): OSet[V] = OSet.from(items)

    /**
      * Builds an `OSet` by appending every element of `items`, in traversal
      * order, to a fresh `Mutable`. Duplicates are handled by `Mutable.append`.
      */
    def from[V](items: TraversableOnce[V]): OSet[V] = {
      val result = new Mutable[V]
      for (item <- items) result.append(item)
      result
    }


    /**
      * [[OSet]] implementation backed by a `mutable.LinkedHashSet`, which
      * provides hash-based membership checks and insertion-ordered iteration.
      *
      * Every transformation (`map`, `filter`, `++`, ...) builds a fresh
      * `Mutable` by appending element by element. A transformation that would
      * produce a duplicate therefore throws when `strictUniqueness` is on and
      * keeps only the first occurrence when it is off.
      */
    class Mutable[V]() extends OSet[V]{

      private[this] val set0 = mutable.LinkedHashSet.empty[V]

      /** Whether `v` is already in the set. */
      def contains(v: V): Boolean = set0.contains(v)

      /**
        * Adds `v` at the end of the set if it is not already present.
        *
        * An already-present element is ignored when `strictUniqueness` is off
        * and rejected with an exception when it is on.
        */
      def append(v: V): Unit = {
        // `add` reports whether the element was newly inserted, so a single
        // hash lookup covers both the membership test and the insertion.
        val inserted = set0.add(v)
        if (!inserted && strictUniqueness){
          throw new Exception("Duplicated item inserted into OrderedSet: " + v)
        }
      }

      /** Appends every element of `vs` in order; see `append` for duplicates. */
      def appendAll(vs: Seq[V]): Unit = vs.foreach(append)

      /** A fresh iterator over the elements, in insertion order. */
      def items: Iterator[V] = set0.iterator

      /** A snapshot of the elements as an indexed sequence. */
      def indexed: IndexedSeq[V] = items.toIndexedSeq

      /** The backing set itself (not a copy), typed as a read-only view. */
      def set: collection.Set[V] = set0

      /** Number of elements; answered by the backing set without iterating. */
      override def size: Int = set0.size

      def map[T](f: V => T): OSet[T] = {
        val output = new OSet.Mutable[T]
        for(i <- items) output.append(f(i))
        output
      }

      def flatMap[T](f: V => TraversableOnce[T]): OSet[T] = {
        val output = new OSet.Mutable[T]
        for(i <- items) for(i0 <- f(i)) output.append(i0)
        output
      }

      def filter(f: V => Boolean): OSet[V] = {
        val output = new OSet.Mutable[V]
        for(i <- items) if (f(i)) output.append(i)
        output
      }

      def withFilter(f: V => Boolean): OSet[V] = filter(f)

      def collect[T](f: PartialFunction[V, T]): OSet[T] = {
        val output = new OSet.Mutable[T]
        // `lift` lets each element go through the partial function once,
        // rather than once for `isDefinedAt` and again for `apply`.
        val lifted = f.lift
        for(i <- items) for(t <- lifted(i)) output.append(t)
        output
      }

      def zipWithIndex: OSet[(V, Int)] = OSet.from(items.zipWithIndex)

      def reverse: OSet[V] = OSet.from(indexed.reverseIterator)

      def zip[T](other: OSet[T]): OSet[(V, T)] = OSet.from(items.zip(other.items))

      def ++[T >: V](other: TraversableOnce[T]): OSet[T] = OSet.from(items ++ other)

      // Members declared in scala.collection.GenTraversableOnce

      // Answered by the backing set: every call to `items` opens a new
      // iterator, so the collection can be traversed repeatedly, whereas a
      // one-shot iterator would report `false` here.
      def isTraversableAgain: Boolean = set0.isTraversableAgain
      def toIterator: Iterator[V] = items
      def toStream: Stream[V] = items.toStream

      // Members declared in scala.collection.TraversableOnce
      def copyToArray[B >: V](xs: Array[B], start: Int,len: Int): Unit = items.copyToArray(xs, start, len)
      def exists(p: V => Boolean): Boolean = items.exists(p)
      def find(p: V => Boolean): Option[V] = items.find(p)
      def forall(p: V => Boolean): Boolean = items.forall(p)
      def foreach[U](f: V => U): Unit = items.foreach(f)
      // Answered by the backing set: an `Iterator` only reports a definite
      // size once it is empty, which is wrong for a finite in-memory set.
      def hasDefiniteSize: Boolean = set0.hasDefiniteSize
      def isEmpty: Boolean = set0.isEmpty
      def seq: scala.collection.TraversableOnce[V] = items
      def toTraversable: Traversable[V] = items.toTraversable

      // The hash is an order-independent sum while `equals` compares elements
      // in order. That is still consistent: sets that are equal contain the
      // same elements and therefore produce the same sum.
      override def hashCode() = items.map(_.hashCode()).sum
      override def equals(other: Any) = other match{
        case s: OSet[_] => items.sameElements(s.items)
        case _ => super.equals(other)
      }
      override def toString = items.mkString("OSet(", ", ", ")")
    }
  }
}