diff options
Diffstat (limited to 'sources/examples/typeinf.scala')
-rw-r--r-- | sources/examples/typeinf.scala | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/sources/examples/typeinf.scala b/sources/examples/typeinf.scala new file mode 100644 index 0000000000..5ff6f03f50 --- /dev/null +++ b/sources/examples/typeinf.scala @@ -0,0 +1,162 @@ +package examples; + +trait Term {} +case class Var(x: String) extends Term {} +case class Lam(x: String, e: Term) extends Term {} +case class App(f: Term, e: Term) extends Term {} +case class Let(x: String, e: Term, f: Term) extends Term {} + +module types { + trait Type {} + case class Tyvar(a: String) extends Type {} + case class Arrow(t1: Type, t2: Type) extends Type {} + case class Tycon(k: String, ts: List[Type]) extends Type {} + private var n: Int = 0; + def newTyvar: Type = { n = n + 1 ; Tyvar("a" + n) } +} +import types._; + +case class ListSet[a](elems: List[a]) { + + def contains(y: a): Boolean = elems match { + case List() => False + case x :: xs => (x == y) || (xs contains y) + } + + def union(ys: ListSet[a]): ListSet[a] = elems match { + case List() => ys + case x :: xs => + if (ys contains x) ListSet(xs) union ys + else ListSet(x :: (ListSet(xs) union ys).elems) + } + + def diff(ys: ListSet[a]): ListSet[a] = elems match { + case List() => ListSet(List()) + case x :: xs => + if (ys contains x) ListSet(xs) diff ys + else ListSet(x :: (ListSet(xs) diff ys).elems) + } +} + +module typeInfer { + + trait Subst extends Function1[Type,Type] { + def lookup(x: Tyvar): Type; + def apply(t: Type): Type = t match { + case Tyvar(a) => val u = lookup(Tyvar(a)); if (t == u) t else apply(u); + case Arrow(t1, t2) => Arrow(apply(t1), apply(t2)) + case Tycon(k, ts) => Tycon(k, ts map apply) + } + def extend(x: Tyvar, t: Type) = new Subst { + def lookup(y: Tyvar): Type = if (x == y) t else Subst.this.lookup(y); + } + } + + val emptySubst = new Subst { def lookup(t: Tyvar): Type = t } + + def tyvars(t: Type): ListSet[String] = t match { + case Tyvar(a) => new ListSet(List(a)) + case Arrow(t1, t2) => tyvars(t1) union tyvars(t2) + case Tycon(k, ts) => tyvars(ts) + } + def tyvars(ts: TypeScheme): ListSet[String] = ts match { + case TypeScheme(vs, t) => tyvars(t) diff new ListSet(vs) + } + def tyvars(ts: List[Type]): ListSet[String] = ts match { + case List() => new ListSet[String](List()) + case t :: ts1 => tyvars(t) union tyvars(ts1) + } + def tyvars(env: Env): ListSet[String] = env match { + case List() => new ListSet[String](List()) + case Pair(x, t) :: env1 => tyvars(t) union tyvars(env1) + } + + case class TypeScheme(vs: List[String], t: Type) { + def newInstance: Type = + (emptySubst foldl_: vs) { (s, a) => s.extend(Tyvar(a), newTyvar) } (t); + } + + type Env = List[Pair[String, TypeScheme]]; + + def lookup(env: Env, x: String): TypeScheme = env match { + case List() => null + case Pair(y, t) :: env1 => if (x == y) t else lookup(env1, x) + } + + def gen(env: Env, t: Type): TypeScheme = + TypeScheme((tyvars(t) diff tyvars(env)).elems, t); + + def mgu(t: Type, u: Type)(s: Subst): Subst = Pair(s(t), s(u)) match { + case Pair(Tyvar( a), Tyvar(b)) if (a == b) => + s + case Pair(Tyvar(a), _) => + if (tyvars(u) contains a) error("unification failure: occurs check") + else s.extend(Tyvar(a), u) + case Pair(_, Tyvar(a)) => + mgu(u, t)(s) + case Pair(Arrow(t1, t2), Arrow(u1, u2)) => + mgu(t1, u1)(mgu(t2, u2)(s)) + case Pair(Tycon(k1, ts), Tycon(k2, us)) if (k1 == k2) => + (s foldl_: ((ts zip us) map {case Pair(t,u) => mgu(t,u)})) { (s, f) => f(s) } + case _ => error("unification failure"); + } + + def tp(env: Env, e: Term, t: Type)(s: Subst): Subst = e match { + case Var(x) => { + val u = lookup(env, x); + if (u == null) error("undefined: x"); + else mgu(u.newInstance, t)(s) + } + case Lam(x, e1) => { + val a = newTyvar, b = newTyvar; + val s1 = mgu(t, Arrow(a, b))(s); + val env1 = Pair(x, TypeScheme(List(), a)) :: env; + tp(env1, e1, b)(s1) + } + case App(e1, e2) => { + val a = newTyvar; + val s1 = tp(env, e1, Arrow(a, t))(s); + tp(env, e2, a)(s1) + } + case Let(x, e1, e2) => { + val a = newTyvar; + val s1 = tp(env, e1, a)(s); + tp(Pair(x, gen(env, s1(a))) :: env, e2, t)(s1) + } + } + + def typeOf(env: Env, e: Term): Type = { + val a = newTyvar; + tp(env, e, a)(emptySubst)(a) + } +} + +module predefined { + val booleanType = Tycon("Boolean", List()); + val intType = Tycon("Int", List()); + def listType(t: Type) = Tycon("List", List(t)); + + private def gen(t: Type): typeInfer.TypeScheme = typeInfer.gen(List(), t); + private val a = newTyvar; + val env = List( + Pair("true", gen(booleanType)), + Pair("false", gen(booleanType)), + Pair("if", gen(Arrow(booleanType, Arrow(a, Arrow(a, a))))), + Pair("zero", gen(intType)), + Pair("succ", gen(Arrow(intType, intType))), + Pair("nil", gen(listType(a))), + Pair("cons", gen(Arrow(a, Arrow(listType(a), listType(a))))), + Pair("isEmpty", gen(Arrow(listType(a), booleanType))), + Pair("head", gen(Arrow(listType(a), a))), + Pair("tail", gen(Arrow(listType(a), listType(a)))), + Pair("fix", gen(Arrow(Arrow(a, a), a))) + ) +} + +module test { + + def showType(e: Term) = typeInfer.typeOf(predefined.env, e); + + showType(Lam("x", App(App(Var("cons"), Var("x")), Var("nil")))); + +}
\ No newline at end of file |