diff options
Diffstat (limited to 'sources/examples/typeinf.scala')
-rw-r--r-- | sources/examples/typeinf.scala | 250 |
1 files changed, 158 insertions, 92 deletions
diff --git a/sources/examples/typeinf.scala b/sources/examples/typeinf.scala index 0a3a96819e..77976be656 100644 --- a/sources/examples/typeinf.scala +++ b/sources/examples/typeinf.scala @@ -1,49 +1,41 @@ package examples; trait Term {} -case class Var(x: String) extends Term {} -case class Lam(x: String, e: Term) extends Term {} -case class App(f: Term, e: Term) extends Term {} -case class Let(x: String, e: Term, f: Term) extends Term {} - -object types { - trait Type {} - case class Tyvar(a: String) extends Type {} - case class Arrow(t1: Type, t2: Type) extends Type {} - case class Tycon(k: String, ts: List[Type]) extends Type {} - private var n: Int = 0; - def newTyvar: Type = { n = n + 1 ; Tyvar("a" + n) } -} -import types._; - -case class ListSet[a](elems: List[a]) { - - def contains(y: a): Boolean = elems match { - case List() => false - case x :: xs => (x == y) || (xs contains y) - } - def union(ys: ListSet[a]): ListSet[a] = elems match { - case List() => ys - case x :: xs => - if (ys contains x) ListSet(xs) union ys - else ListSet(x :: (ListSet(xs) union ys).elems) - } +case class Var(x: String) extends Term { + override def toString() = x +} +case class Lam(x: String, e: Term) extends Term { + override def toString() = "(\\" + x + "." + e + ")" +} +case class App(f: Term, e: Term) extends Term { + override def toString() = "(" + f + " " + e + ")" +} +case class Let(x: String, e: Term, f: Term) extends Term { + override def toString() = "let " + x + " = " + e + " in " + f; +} - def diff(ys: ListSet[a]): ListSet[a] = elems match { - case List() => ListSet(List()) - case x :: xs => - if (ys contains x) ListSet(xs) diff ys - else ListSet(x :: (ListSet(xs) diff ys).elems) - } +sealed trait Type {} +case class Tyvar(a: String) extends Type { + override def toString() = a +} +case class Arrow(t1: Type, t2: Type) extends Type { + override def toString() = "(" + t1 + "->" + t2 + ")" +} +case class Tycon(k: String, ts: List[Type]) extends Type { + override def toString() = + k + (if (ts.isEmpty) "" else ts.mkString("[", ",", "]")) } object typeInfer { + private var n: Int = 0; + def newTyvar(): Type = { n = n + 1 ; Tyvar("a" + n) } + trait Subst with Function1[Type,Type] { def lookup(x: Tyvar): Type; def apply(t: Type): Type = t match { - case Tyvar(a) => val u = lookup(Tyvar(a)); if (t == u) t else apply(u); + case tv @ Tyvar(a) => val u = lookup(tv); if (t == u) t else apply(u); case Arrow(t1, t2) => Arrow(apply(t1), apply(t2)) case Tycon(k, ts) => Tycon(k, ts map apply) } @@ -54,26 +46,9 @@ object typeInfer { val emptySubst = new Subst { def lookup(t: Tyvar): Type = t } - def tyvars(t: Type): ListSet[String] = t match { - case Tyvar(a) => new ListSet(List(a)) - case Arrow(t1, t2) => tyvars(t1) union tyvars(t2) - case Tycon(k, ts) => tyvars(ts) - } - def tyvars(ts: TypeScheme): ListSet[String] = ts match { - case TypeScheme(vs, t) => tyvars(t) diff new ListSet(vs) - } - def tyvars(ts: List[Type]): ListSet[String] = ts match { - case List() => new ListSet[String](List()) - case t :: ts1 => tyvars(t) union tyvars(ts1) - } - def tyvars(env: Env): ListSet[String] = env match { - case List() => new ListSet[String](List()) - case Pair(x, t) :: env1 => tyvars(t) union tyvars(env1) - } - - case class TypeScheme(vs: List[String], t: Type) { + case class TypeScheme(tyvars: List[Tyvar], tpe: Type) { def newInstance: Type = - vs.foldLeft(emptySubst) { (s, a) => s.extend(Tyvar(a), newTyvar) } (t); + (emptySubst /: tyvars) ((s, tv) => s.extend(tv, newTyvar())) (tpe); } type Env = List[Pair[String, TypeScheme]]; @@ -84,50 +59,66 @@ object typeInfer { } def gen(env: Env, t: Type): TypeScheme = - TypeScheme((tyvars(t) diff tyvars(env)).elems, t); + TypeScheme(tyvars(t) diff tyvars(env), t); - def mgu(t: Type, u: Type)(s: Subst): Subst = Pair(s(t), s(u)) match { - case Pair(Tyvar( a), Tyvar(b)) if (a == b) => + def tyvars(t: Type): List[Tyvar] = t match { + case tv @ Tyvar(a) => List(tv) + case Arrow(t1, t2) => tyvars(t1) union tyvars(t2) + case Tycon(k, ts) => (List[Tyvar]() /: ts) ((tvs, t) => tvs union tyvars(t)); + } + + def tyvars(ts: TypeScheme): List[Tyvar] = + tyvars(ts.tpe) diff ts.tyvars; + + def tyvars(env: Env): List[Tyvar] = + (List[Tyvar]() /: env) ((tvs, nt) => tvs union tyvars(nt._2)); + + def mgu(t: Type, u: Type, s: Subst): Subst = Pair(s(t), s(u)) match { + case Pair(Tyvar(a), Tyvar(b)) if (a == b) => s - case Pair(Tyvar(a), _) => - if (tyvars(u) contains a) error("unification failure: occurs check") - else s.extend(Tyvar(a), u) + case Pair(Tyvar(a), _) if !(tyvars(u) contains a) => + s.extend(Tyvar(a), u) case Pair(_, Tyvar(a)) => - mgu(u, t)(s) + mgu(u, t, s) case Pair(Arrow(t1, t2), Arrow(u1, u2)) => - mgu(t1, u1)(mgu(t2, u2)(s)) + mgu(t1, u1, mgu(t2, u2, s)) case Pair(Tycon(k1, ts), Tycon(k2, us)) if (k1 == k2) => - ((ts zip us) map {case Pair(t,u) => mgu(t,u)}).foldLeft(s) { (s, f) => f(s) } - case _ => error("unification failure"); + (s /: (ts zip us)) ((s, tu) => mgu(tu._1, tu._2, s)) + case _ => throw new TypeError("cannot unify " + s(t) + " with " + s(u)) } - def tp(env: Env, e: Term, t: Type)(s: Subst): Subst = e match { - case Var(x) => { - val u = lookup(env, x); - if (u == null) error("undefined: x"); - else mgu(u.newInstance, t)(s) - } - case Lam(x, e1) => { - val a = newTyvar, b = newTyvar; - val s1 = mgu(t, Arrow(a, b))(s); - val env1 = Pair(x, TypeScheme(List(), a)) :: env; - tp(env1, e1, b)(s1) - } - case App(e1, e2) => { - val a = newTyvar; - val s1 = tp(env, e1, Arrow(a, t))(s); - tp(env, e2, a)(s1) - } - case Let(x, e1, e2) => { - val a = newTyvar; - val s1 = tp(env, e1, a)(s); - tp(Pair(x, gen(env, s1(a))) :: env, e2, t)(s1) + case class TypeError(s: String) extends Exception(s) {} + + def tp(env: Env, e: Term, t: Type, s: Subst): Subst = { + current = e; + e match { + case Var(x) => + val u = lookup(env, x); + if (u == null) throw new TypeError("undefined: " + x); + else mgu(u.newInstance, t, s) + + case Lam(x, e1) => + val a = newTyvar(), b = newTyvar(); + val s1 = mgu(t, Arrow(a, b), s); + val env1 = Pair(x, TypeScheme(List(), a)) :: env; + tp(env1, e1, b, s1) + + case App(e1, e2) => + val a = newTyvar(); + val s1 = tp(env, e1, Arrow(a, t), s); + tp(env, e2, a, s1) + + case Let(x, e1, e2) => + val a = newTyvar(); + val s1 = tp(env, e1, a, s); + tp(Pair(x, gen(env, s1(a))) :: env, e2, t, s1) } } + var current: Term = null; def typeOf(env: Env, e: Term): Type = { - val a = newTyvar; - tp(env, e, a)(emptySubst)(a) + val a = newTyvar(); + tp(env, e, a, emptySubst)(a) } } @@ -137,7 +128,7 @@ object predefined { def listType(t: Type) = Tycon("List", List(t)); private def gen(t: Type): typeInfer.TypeScheme = typeInfer.gen(List(), t); - private val a = newTyvar; + private val a = typeInfer.newTyvar(); val env = List( /* Pair("true", gen(booleanType)), @@ -155,11 +146,86 @@ object predefined { ) } -object test with Executable { +abstract class MiniMLParsers[intype] extends CharParsers[intype] { + + /** whitespace */ + def whitespace = rep{chr(' ') ||| chr('\t') ||| chr('\n')}; + + /** A given character, possible preceded by whitespace */ + def wschr(ch: char) = whitespace &&& chr(ch); + + /** identifiers or keywords */ + def id: Parser[String] = + for ( + val c: char <- rep(chr(' ')) &&& chr(Character.isLetter); + val cs: List[char] <- rep(chr(Character.isLetterOrDigit)) + ) yield (c :: cs).mkString("", "", ""); + + /** Non-keyword identifiers */ + def ident: Parser[String] = + for (val s <- id; s != "let" && s != "in") yield s; + + /** term = '\' ident '.' term | term1 {term1} | let ident "=" term in term */ + def term: Parser[Term] = + ( for ( + val _ <- wschr('\\'); + val x <- ident; + val _ <- wschr('.'); + val t <- term) + yield Lam(x, t): Term ) + ||| + ( for ( + val letid <- id; letid == "let"; + val x <- ident; + val _ <- wschr('='); + val t <- term; + val inid <- id; inid == "in"; + val c <- term) + yield Let(x, t, c) ) + ||| + ( for ( + val t <- term1; + val ts <- rep(term1)) + yield (t /: ts)((f, arg) => App(f, arg)) ); + + /** term1 = ident | '(' term ')' */ + def term1: Parser[Term] = + ( for (val s <- ident) + yield Var(s): Term ) + ||| + ( for ( + val _ <- wschr('('); + val t <- term; + val _ <- wschr(')')) + yield t ); + + /** all = term ';' */ + def all: Parser[Term] = + for ( + val t <- term; + val _ <- wschr(';')) + yield t; +} + +object testInfer { -// def showType(e: Term) = typeInfer.typeOf(predefined.env, e); - def showType(e: Term) = typeInfer.typeOf(Nil, e); + def showType(e: Term): String = + try { + typeInfer.typeOf(predefined.env, e).toString(); + } catch { + case typeInfer.TypeError(msg) => + "\n cannot type: " + typeInfer.current + + "\n reason: " + msg; + } - Console.println( - showType(Lam("x", App(App(Var("cons"), Var("x")), Var("nil"))))); + def main(args: Array[String]): unit = { + val ps = new MiniMLParsers[int] with ParseString(args(0)); + ps.all(ps.input) match { + case Some(Pair(term, _)) => + System.out.println("" + term + ": " + showType(term)); + case None => + System.out.println("syntax error"); + } + } } + |