summaryrefslogtreecommitdiff
path: root/sources/examples/typeinf.scala
blob: af94f9d3c446f9d1a50a7824a5475006042295c7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package examples;

/** Terms of the untyped lambda calculus extended with `let`. */
trait Term {}

/** A reference to the variable named `x`. */
case class Var(x: String) extends Term {}

/** The abstraction `\x. e`. */
case class Lam(x: String, e: Term) extends Term {}

/** The application of `f` to the argument `e`. */
case class App(f: Term, e: Term) extends Term {}

/** `let x = e in f`: binds `x` to `e` inside `f`. */
case class Let(x: String, e: Term, f: Term) extends Term {}

module types {

  /** Monotypes: type variables, function types and applied type constructors. */
  trait Type {}

  /** A type variable, identified by its name. */
  case class Tyvar(a: String) extends Type {}

  /** The function type `t1 -> t2`. */
  case class Arrow(t1: Type, t2: Type) extends Type {}

  /** The type constructor named `k` applied to the argument types `ts`. */
  case class Tycon(k: String, ts: List[Type]) extends Type {}

  /** How many type variables have been handed out so far. */
  private var counter: Int = 0;

  /** Returns a type variable named "a1", "a2", ...; each call bumps the
   *  counter first, so no name is returned twice. */
  def newTyvar: Type = {
    counter = counter + 1;
    Tyvar("a" + counter)
  }
}
import types._;

/** A set represented as a plain list of elements; `union` and `diff`
 *  treat the list as a set, comparing elements with `==`. */
case class ListSet[a](elems: List[a]) {

  /** Whether `y` equals (==) some element of this set. */
  def contains(y: a): Boolean = elems match {
    case List() => False
    case x :: xs => (xs contains y) || (x == y)
  }

  /** Elements of this set not in `ys`, followed by all elements of `ys`. */
  def union(ys: ListSet[a]): ListSet[a] = elems match {
    case List() => ys
    case x :: xs =>
      val rest = ListSet(xs) union ys;
      if (ys contains x) rest else ListSet(x :: rest.elems)
  }

  /** The elements of this set that do not occur in `ys`, in original order. */
  def diff(ys: ListSet[a]): ListSet[a] = elems match {
    case List() => ListSet(List())
    case x :: xs =>
      val rest = ListSet(xs) diff ys;
      if (ys contains x) rest else ListSet(x :: rest.elems)
  }
}

module typeInfer {

  /** A substitution from type variables to types, usable as a function on
   *  types. A variable is unbound when `lookup` maps it to itself. */
  trait Subst with Function1[Type,Type] {

    /** The binding of `x`, or `x` itself when `x` is unbound. */
    def lookup(x: Tyvar): Type;

    /** Applies this substitution to `t`. */
    def apply(t: Type): Type = t match {
      // Re-apply to the binding so chains (a -> b, b -> Int) are fully
      // resolved. This terminates only if bindings are acyclic, which the
      // occurs check in `mgu` is responsible for guaranteeing.
      case Tyvar(a) => val u = lookup(Tyvar(a)); if (t == u) t else apply(u);
      case Arrow(t1, t2) => Arrow(apply(t1), apply(t2))
      case Tycon(k, ts) => Tycon(k, ts map apply)
    }

    /** This substitution plus the binding `x -> t`, which takes precedence
     *  over any existing binding of `x`. */
    def extend(x: Tyvar, t: Type) = new Subst {
      def lookup(y: Tyvar): Type = if (x == y) t else Subst.this.lookup(y);
    }
  }

  /** The identity substitution: every type variable is unbound. */
  val emptySubst = new Subst { def lookup(t: Tyvar): Type = t }

  /** Names of the type variables occurring in `t`. */
  def tyvars(t: Type): ListSet[String] = t match {
    case Tyvar(a) => new ListSet(List(a))
    case Arrow(t1, t2) => tyvars(t1) union tyvars(t2)
    case Tycon(k, ts) => tyvars(ts)
  }

  /** Free type variables of a scheme: those of its body minus the quantified ones. */
  def tyvars(ts: TypeScheme): ListSet[String] = ts match {
    case TypeScheme(vs, t) => tyvars(t) diff new ListSet(vs)
  }

  /** Union of the type variables of every type in `ts`. */
  def tyvars(ts: List[Type]): ListSet[String] = ts match {
    case List() => new ListSet[String](List())
    case t :: ts1 => tyvars(t) union tyvars(ts1)
  }

  /** Free type variables of every scheme in `env`, taken as-is (no
   *  substitution applied). */
  def tyvars(env: Env): ListSet[String] = env match {
    case List() => new ListSet[String](List())
    case Pair(x, t) :: env1 => tyvars(t) union tyvars(env1)
  }

  /** Free type variables of `env` seen through substitution `s`: each free
   *  variable `v` of the environment contributes the variables of `s(v)`.
   *  Let-generalization must exclude exactly these, because `s` may already
   *  have bound environment variables to types mentioning other variables. */
  def envTyvars(env: Env, s: Subst): ListSet[String] =
    (new ListSet[String](List()) foldl_: tyvars(env).elems) {
      (found, v) => found union tyvars(s(Tyvar(v)))
    };

  /** The type scheme `forall vs. t`. */
  case class TypeScheme(vs: List[String], t: Type) {
    /** Instantiates the scheme, replacing each quantified variable with a
     *  fresh type variable. */
    def newInstance: Type =
      (emptySubst foldl_: vs) { (s, a) => s.extend(Tyvar(a), newTyvar) } (t);
  }

  /** A typing environment; `lookup` returns the first matching entry, so
   *  earlier entries shadow later ones. */
  type Env = List[Pair[String, TypeScheme]];

  /** The scheme bound to `x` in `env`, or null when `x` is not bound. */
  def lookup(env: Env, x: String): TypeScheme = env match {
    case List() => null
    case Pair(y, t) :: env1 => if (x == y) t else lookup(env1, x)
  }

  /** Quantifies `t` over its type variables that are not free in `env`.
   *  `env` is used as given; no substitution is applied to it. */
  def gen(env: Env, t: Type): TypeScheme =
    TypeScheme((tyvars(t) diff tyvars(env)).elems, t);

  /** Extends `s` to a unifier of `t` and `u`, or calls `error` when they
   *  cannot be unified. */
  def mgu(t: Type, u: Type)(s: Subst): Subst = Pair(s(t), s(u)) match {
    case Pair(Tyvar(a), Tyvar(b)) if (a == b) =>
      s
    case Pair(Tyvar(a), su) =>
      // Occurs check on s(u), not u: a variable in u may already be bound
      // to a type containing `a`; binding `a` would then form a cycle and
      // make Subst.apply loop forever.
      if (tyvars(su) contains a) error("unification failure: occurs check")
      else s.extend(Tyvar(a), su)
    case Pair(_, Tyvar(a)) =>
      mgu(u, t)(s)
    case Pair(Arrow(t1, t2), Arrow(u1, u2)) =>
      mgu(t1, u1)(mgu(t2, u2)(s))
    // Require equal arity: `zip` would otherwise silently drop extra arguments.
    case Pair(Tycon(k1, ts), Tycon(k2, us)) if (k1 == k2 && ts.length == us.length) =>
      (s foldl_: ((ts zip us) map {case Pair(t,u) => mgu(t,u)})) { (s, f) => f(s) }
    case _ => error("unification failure");
  }

  /** Extends `s` so that term `e` has the expected type `t` under `env`,
   *  calling `error` on undefined names or unification failure. */
  def tp(env: Env, e: Term, t: Type)(s: Subst): Subst = e match {
    case Var(x) => {
      val u = lookup(env, x);
      if (u == null) error("undefined: " + x)
      else mgu(u.newInstance, t)(s)
    }
    case Lam(x, e1) => {
      val a = newTyvar, b = newTyvar;
      val s1 = mgu(t, Arrow(a, b))(s);
      // The parameter is monomorphic: its scheme quantifies nothing.
      val env1 = Pair(x, TypeScheme(List(), a)) :: env;
      tp(env1, e1, b)(s1)
    }
    case App(e1, e2) => {
      val a = newTyvar;
      val s1 = tp(env, e1, Arrow(a, t))(s);
      tp(env, e2, a)(s1)
    }
    case Let(x, e1, e2) => {
      val a = newTyvar;
      val s1 = tp(env, e1, a)(s);
      // Generalize against the environment as seen through s1; using the raw
      // env would wrongly quantify variables s1 already tied to lambda-bound
      // names (e.g. \x. let y = x in y would not get type b -> b).
      val t1 = s1(a);
      val scheme = TypeScheme((tyvars(t1) diff envTyvars(env, s1)).elems, t1);
      tp(Pair(x, scheme) :: env, e2, t)(s1)
    }
  }

  /** Infers the type of `e` under `env`: checks `e` against a fresh type
   *  variable, then resolves that variable through the final substitution. */
  def typeOf(env: Env, e: Term): Type = {
    val a = newTyvar;
    tp(env, e, a)(emptySubst)(a)
  }
}

module predefined {
  val booleanType = Tycon("Boolean", List());
  val intType = Tycon("Int", List());
  def listType(t: Type) = Tycon("List", List(t));

  /** Quantifies `t` over all of its type variables (empty environment). */
  private def poly(t: Type): typeInfer.TypeScheme = typeInfer.gen(List(), t);

  /** The one type variable used by all polymorphic signatures below; each
   *  scheme quantifies over it separately, since `poly` closes over every
   *  variable of its argument. */
  private val alpha = newTyvar;
  private val alphaList = listType(alpha);

  /** Initial typing environment: built-in constants and their schemes. */
  val env = List(
    Pair("true", poly(booleanType)),
    Pair("false", poly(booleanType)),
    Pair("if", poly(Arrow(booleanType, Arrow(alpha, Arrow(alpha, alpha))))),
    Pair("zero", poly(intType)),
    Pair("succ", poly(Arrow(intType, intType))),
    Pair("nil", poly(alphaList)),
    Pair("cons", poly(Arrow(alpha, Arrow(alphaList, alphaList)))),
    Pair("isEmpty", poly(Arrow(alphaList, booleanType))),
    Pair("head", poly(Arrow(alphaList, alpha))),
    Pair("tail", poly(Arrow(alphaList, alphaList))),
    Pair("fix", poly(Arrow(Arrow(alpha, alpha), alpha)))
  )
}

module test {

  /** Infers the type of `e` under the predefined environment. */
  def showType(e: Term) = typeInfer.typeOf(predefined.env, e);

  /** Sample term `\x. cons x nil`. */
  private val singleton = Lam("x", App(App(Var("cons"), Var("x")), Var("nil")));

  // Runs inference on the sample; the inferred type is not used further.
  showType(singleton);

}