summaryrefslogtreecommitdiff
path: root/sources/examples/typeinf.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sources/examples/typeinf.scala')
-rw-r--r--sources/examples/typeinf.scala162
1 files changed, 162 insertions, 0 deletions
diff --git a/sources/examples/typeinf.scala b/sources/examples/typeinf.scala
new file mode 100644
index 0000000000..5ff6f03f50
--- /dev/null
+++ b/sources/examples/typeinf.scala
@@ -0,0 +1,162 @@
+package examples;
+
+trait Term {}
+case class Var(x: String) extends Term {}
+case class Lam(x: String, e: Term) extends Term {}
+case class App(f: Term, e: Term) extends Term {}
+case class Let(x: String, e: Term, f: Term) extends Term {}
+
+module types {
+ trait Type {}
+ case class Tyvar(a: String) extends Type {}
+ case class Arrow(t1: Type, t2: Type) extends Type {}
+ case class Tycon(k: String, ts: List[Type]) extends Type {}
+ private var n: Int = 0;
+ def newTyvar: Type = { n = n + 1 ; Tyvar("a" + n) }
+}
+import types._;
+
+case class ListSet[a](elems: List[a]) {
+
+ def contains(y: a): Boolean = elems match {
+ case List() => False
+ case x :: xs => (x == y) || (xs contains y)
+ }
+
+ def union(ys: ListSet[a]): ListSet[a] = elems match {
+ case List() => ys
+ case x :: xs =>
+ if (ys contains x) ListSet(xs) union ys
+ else ListSet(x :: (ListSet(xs) union ys).elems)
+ }
+
+ def diff(ys: ListSet[a]): ListSet[a] = elems match {
+ case List() => ListSet(List())
+ case x :: xs =>
+ if (ys contains x) ListSet(xs) diff ys
+ else ListSet(x :: (ListSet(xs) diff ys).elems)
+ }
+}
+
+module typeInfer {
+
+ trait Subst extends Function1[Type,Type] {
+ def lookup(x: Tyvar): Type;
+ def apply(t: Type): Type = t match {
+ case Tyvar(a) => val u = lookup(Tyvar(a)); if (t == u) t else apply(u);
+ case Arrow(t1, t2) => Arrow(apply(t1), apply(t2))
+ case Tycon(k, ts) => Tycon(k, ts map apply)
+ }
+ def extend(x: Tyvar, t: Type) = new Subst {
+ def lookup(y: Tyvar): Type = if (x == y) t else Subst.this.lookup(y);
+ }
+ }
+
+ val emptySubst = new Subst { def lookup(t: Tyvar): Type = t }
+
+ def tyvars(t: Type): ListSet[String] = t match {
+ case Tyvar(a) => new ListSet(List(a))
+ case Arrow(t1, t2) => tyvars(t1) union tyvars(t2)
+ case Tycon(k, ts) => tyvars(ts)
+ }
+ def tyvars(ts: TypeScheme): ListSet[String] = ts match {
+ case TypeScheme(vs, t) => tyvars(t) diff new ListSet(vs)
+ }
+ def tyvars(ts: List[Type]): ListSet[String] = ts match {
+ case List() => new ListSet[String](List())
+ case t :: ts1 => tyvars(t) union tyvars(ts1)
+ }
+ def tyvars(env: Env): ListSet[String] = env match {
+ case List() => new ListSet[String](List())
+ case Pair(x, t) :: env1 => tyvars(t) union tyvars(env1)
+ }
+
+ case class TypeScheme(vs: List[String], t: Type) {
+ def newInstance: Type =
+ (emptySubst foldl_: vs) { (s, a) => s.extend(Tyvar(a), newTyvar) } (t);
+ }
+
+ type Env = List[Pair[String, TypeScheme]];
+
+ def lookup(env: Env, x: String): TypeScheme = env match {
+ case List() => null
+ case Pair(y, t) :: env1 => if (x == y) t else lookup(env1, x)
+ }
+
+ def gen(env: Env, t: Type): TypeScheme =
+ TypeScheme((tyvars(t) diff tyvars(env)).elems, t);
+
+ def mgu(t: Type, u: Type)(s: Subst): Subst = Pair(s(t), s(u)) match {
+ case Pair(Tyvar( a), Tyvar(b)) if (a == b) =>
+ s
+ case Pair(Tyvar(a), _) =>
+ if (tyvars(u) contains a) error("unification failure: occurs check")
+ else s.extend(Tyvar(a), u)
+ case Pair(_, Tyvar(a)) =>
+ mgu(u, t)(s)
+ case Pair(Arrow(t1, t2), Arrow(u1, u2)) =>
+ mgu(t1, u1)(mgu(t2, u2)(s))
+ case Pair(Tycon(k1, ts), Tycon(k2, us)) if (k1 == k2) =>
+ (s foldl_: ((ts zip us) map {case Pair(t,u) => mgu(t,u)})) { (s, f) => f(s) }
+ case _ => error("unification failure");
+ }
+
+ def tp(env: Env, e: Term, t: Type)(s: Subst): Subst = e match {
+ case Var(x) => {
+ val u = lookup(env, x);
+ if (u == null) error("undefined: x");
+ else mgu(u.newInstance, t)(s)
+ }
+ case Lam(x, e1) => {
+ val a = newTyvar, b = newTyvar;
+ val s1 = mgu(t, Arrow(a, b))(s);
+ val env1 = Pair(x, TypeScheme(List(), a)) :: env;
+ tp(env1, e1, b)(s1)
+ }
+ case App(e1, e2) => {
+ val a = newTyvar;
+ val s1 = tp(env, e1, Arrow(a, t))(s);
+ tp(env, e2, a)(s1)
+ }
+ case Let(x, e1, e2) => {
+ val a = newTyvar;
+ val s1 = tp(env, e1, a)(s);
+ tp(Pair(x, gen(env, s1(a))) :: env, e2, t)(s1)
+ }
+ }
+
+ def typeOf(env: Env, e: Term): Type = {
+ val a = newTyvar;
+ tp(env, e, a)(emptySubst)(a)
+ }
+}
+
+module predefined {
+ val booleanType = Tycon("Boolean", List());
+ val intType = Tycon("Int", List());
+ def listType(t: Type) = Tycon("List", List(t));
+
+ private def gen(t: Type): typeInfer.TypeScheme = typeInfer.gen(List(), t);
+ private val a = newTyvar;
+ val env = List(
+ Pair("true", gen(booleanType)),
+ Pair("false", gen(booleanType)),
+ Pair("if", gen(Arrow(booleanType, Arrow(a, Arrow(a, a))))),
+ Pair("zero", gen(intType)),
+ Pair("succ", gen(Arrow(intType, intType))),
+ Pair("nil", gen(listType(a))),
+ Pair("cons", gen(Arrow(a, Arrow(listType(a), listType(a))))),
+ Pair("isEmpty", gen(Arrow(listType(a), booleanType))),
+ Pair("head", gen(Arrow(listType(a), a))),
+ Pair("tail", gen(Arrow(listType(a), listType(a)))),
+ Pair("fix", gen(Arrow(Arrow(a, a), a)))
+ )
+}
+
+module test {
+
+ def showType(e: Term) = typeInfer.typeOf(predefined.env, e);
+
+ showType(Lam("x", App(App(Var("cons"), Var("x")), Var("nil"))));
+
+} \ No newline at end of file