// path: root/src/compiler/scala/tools/nsc/matching/PatternBindings.scala
// blob: 8a36216e779c8116bc79dc331e88aed537cec1d3
/* NSC -- new Scala compiler
 * Copyright 2005-2010 LAMP/EPFL
 * Author: Paul Phillips
 */

package scala.tools.nsc
package matching

import transform.ExplicitOuter
import collection.immutable.TreeMap
import PartialFunction._

trait PatternBindings extends ast.TreeDSL
{
  self: ExplicitOuter with ParallelMatching =>

  import global.{ typer => _, _ }
  import definitions.{ EqualsPatternClass }
  import CODE._
  import Debug._

  /** EqualsPattern **/
  // True exactly when `tpe` is a TypeRef whose symbol is EqualsPatternClass.
  def isEquals(tpe: Type): Boolean = tpe match {
    case TypeRef(_, EqualsPatternClass, _) => true
    case _                                 => false
  }
  // Builds a prefix-less type reference to EqualsPatternClass with `tpe` as its
  // single type argument.
  def mkEqualsRef(tpe: Type) = {
    val typeArgs = tpe :: Nil
    typeRef(NoPrefix, EqualsPatternClass, typeArgs)
  }
  // Inverse of mkEqualsRef: a reference to EqualsPatternClass with exactly one type
  // argument yields that argument; every other type is handed back unchanged.
  def decodedEqualsType(tpe: Type): Type = tpe match {
    case TypeRef(_, EqualsPatternClass, List(arg)) => arg
    case _                                         => tpe
  }

  // Used as argument to `EqualsPatternClass'.
  // A type proxy wrapping a tree: it stands in for the type of the tree `o`.
  case class PseudoType(o: Tree) extends SimpleTypeProxy {
    // The proxied type is whatever type the wrapped tree currently carries.
    override def underlying: Type = o.tpe
    // Printed with the wrapped tree itself (not its type) between the parentheses.
    override def safeToString: String = "PseudoType("+o+")"
  }

  // For a pattern which contains alternatives: returns those alternatives as a list
  // of patterns.  The search starts at the pattern's boundTree with the identity
  // function as the initial "re-apply bindings" step; _extractBindings makes typed
  // copies of the bindings it finds so the alternatives point to final state.
  def extractBindings(p: Pattern): List[Pattern] = {
    val alternatives = _extractBindings(p.boundTree, identity)
    toPats(alternatives)
  }

  // Walks down through nested Binds to the Alternative they enclose and returns one
  // tree per alternative, each wrapped in typed copies of ALL the Binds passed on the
  // way down (outermost Bind outermost in the result).
  //
  // `prevBindings` re-applies the Binds seen so far to a tree.  Every newly met Bind
  // must be composed with it: replacing it outright would keep only the innermost
  // binding, e.g. for  x @ (y @ (A | B))  the variable x would be lost from the
  // expanded alternatives.
  //
  // Only Bind and Alternative are handled; any other tree shape raises a MatchError,
  // so callers must pass a pattern which is known to contain alternatives.
  private def _extractBindings(p: Tree, prevBindings: Tree => Tree): List[Tree] = {
    // Wrap x in a typed copy of b, then in everything that enclosed b.
    def newPrev(b: Bind) = (x: Tree) => prevBindings(treeCopy.Bind(b, b.name, x) setType x.tpe)

    p match {
      case b @ Bind(_, body)  => _extractBindings(body, newPrev(b))
      case Alternative(ps)    => ps map prevBindings
    }
  }

  trait PatternBindingLogic {
    self: Pattern =>

    // Traversal hook for the pattern tree: pattern types which might have bound
    // variables beneath them override this to return those patterns so they can be
    // flatMapped over.  By default there are none.
    def subpatternsForVars: List[Pattern] = Nil

    // Symbols of the chain of Binds at the top of boundTree.
    private def shallowBoundVariables = strip(boundTree)
    // Deep bound variables of every pattern reported by subpatternsForVars.
    private def otherBoundVariables = subpatternsForVars flatMap (sub => sub.deepBoundVariables)

    // This pattern's own bindings followed by those of its subpatterns.
    // (An indiscriminate deep search would instead be: deepstrip(boundTree).)
    def deepBoundVariables: List[Symbol] = shallowBoundVariables ::: otherBoundVariables

    // The shallow bindings, computed once on first access.  A TRACE message is
    // emitted when the deep list does not have the same size.
    lazy val boundVariables = {
      val shallow = shallowBoundVariables
      val deep    = deepBoundVariables

      if (shallow.size != deep.size)
        TRACE("deep variable list %s is larger than bound %s", deep, shallow)

      shallow
    }

    // XXX only a var for short-term experimentation.
    private var _boundTree: Bind = null
    // The explicitly installed bound tree when there is one, else the pattern's tree.
    def boundTree = _boundTree match {
      case null  => tree
      case bound => bound
    }
    // Installs `x` as the bound tree and returns this pattern, passed through tracing.
    def withBoundTree(x: Bind): this.type = {
      _boundTree = x
      tracing[this.type]("Bound")(this)
    }

    // A tree with bindings has a boundTree looking something like
    //   Bind(v3, Bind(v2, Bind(v1, tree)))
    // Given another tree, this creates a new pattern around it which
    // carries the same (shallow) bindings.
    def rebindTo(t: Tree): Pattern = {
      val dropsBindings = boundVariables.size < deepBoundVariables.size
      if (dropsBindings)
        TRACE("ALERT: rebinding %s is losing %s", this, otherBoundVariables)

      val rebound = wrapBindings(boundVariables, t)
      Pattern(rebound)
    }

    // Wraps this pattern's bindings around (_: Type).  The type tree of the
    // ascription uses annotatedType when one is supplied, and tpe otherwise.
    def rebindToType(tpe: Type, annotatedType: Type = null): Pattern = {
      val ascribed = if (annotatedType != null) annotatedType else tpe
      val typedWild = Typed(WILD(tpe), TypeTree(ascribed)) setType tpe
      rebindTo(typedWild)
    }

    // Wraps the bindings around a Typed node whose expression is EmptyTree.
    def rebindToEmpty(tpe: Type): Pattern = {
      val typedEmpty = Typed(EmptyTree, TypeTree(tpe)) setType tpe
      rebindTo(typedEmpty)
    }

    // Wraps the bindings around the pattern's equalsCheck type (a singleton type
    // for an EqualsPattern check).
    def rebindToEqualsCheck(): Pattern =
      rebindToType(equalsCheck)

    // Like rebindToEqualsCheck, but subtly different: the wildcard gets the
    // EqualsPattern reference built from sufficientType, while the ascription
    // keeps sufficientType itself.  (The original author notes this has not
    // been fully sorted out yet.)
    def rebindToObjectCheck(): Pattern = {
      val sufficient = sufficientType
      rebindToType(tpe = mkEqualsRef(sufficient), annotatedType = sufficient)
    }

    /** Helpers **/

    // Nests `pat` inside one Bind per symbol, first symbol outermost; each Bind
    // is given the type of `pat`.
    private def wrapBindings(vs: List[Symbol], pat: Tree): Tree =
      vs.foldRight(pat) { (sym, inner) => Bind(sym, inner) setType pat.tpe }

    // Collects the symbols of the leading chain of Binds, outermost first,
    // stopping at the first tree which is not a Bind.
    private def strip(t: Tree): List[Symbol] = {
      def loop(current: Tree, seen: List[Symbol]): List[Symbol] = current match {
        case bind @ Bind(_, inner) => loop(inner, bind.symbol :: seen)
        case _                     => seen.reverse
      }
      loop(t, Nil)
    }

    // Symbols of the Bind nodes gathered by treeCollect over `t`.
    private def deepstrip(t: Tree): List[Symbol] =
      treeCollect(t, { case bind: Bind => bind.symbol })
  }

  // Associates the symbol `pvar` with the symbol `tvar`.
  // NOTE(review): presumably pattern variable -> temporary variable; confirm with callers.
  case class Binding(pvar: Symbol, tvar: Symbol) {
    // Rendered by pp as the pair (pvar, tvar).
    override def toString() = {
      val pair = (pvar, tvar)
      pp(pair)
    }
  }

  // A list of Binding entries.  `add` never modifies the receiver; it returns a
  // fresh Bindings instead.
  class Bindings(private val vlist: List[Binding]) {
    // if (!vlist.isEmpty)
    //   traceCategory("Bindings", this.toString)

    // The underlying list of bindings.
    def get() = vlist

    // A new Bindings in which every symbol of `vs` is bound to `tvar`; the new
    // entries are placed ahead of the ones already held here.
    def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = {
      val added = for (v <- vs.toList) yield Binding(v, tvar)
      new Bindings(added ::: vlist)
    }

    override def toString() = vlist match {
      case Nil => "No Bindings"
      case _   => "%d Bindings(%s)".format(vlist.size, pp(vlist))
    }
  }

  // The Bindings instance which holds no bindings at all.
  val NoBinding: Bindings = new Bindings(List())
}