summaryrefslogblamecommitdiff
path: root/sources/examples/typeinf.scala
blob: af94f9d3c446f9d1a50a7824a5475006042295c7 (plain) (tree)









































                                                             
                                         





















































































































                                                                                   
 
package examples;

// Abstract syntax of the object language: lambda-calculus terms extended
// with let-bindings. These are the terms that typeInfer.tp matches on.
trait Term {}
// A reference to the variable named `x`.
case class Var(x: String)                   extends Term {}
// Lambda abstraction binding `x` in the body `e`.
case class Lam(x: String, e: Term)          extends Term {}
// Application of the function term `f` to the argument term `e`.
case class App(f: Term, e: Term)            extends Term {}
// `let x = e in f`: binds `x` to `e` within `f`.
case class Let(x: String, e: Term, f: Term) extends Term {}

module types {
  // Type syntax: type variables, function arrows, and applied type constructors.
  trait Type {}
  case class Tyvar(a: String) extends Type {}
  case class Arrow(t1: Type, t2: Type) extends Type {}
  case class Tycon(k: String, ts: List[Type]) extends Type {}

  // Counter backing newTyvar; it only ever increases.
  private var n: Int = 0;

  /** Returns a type variable named "a" followed by a number that is one
   *  larger than the one used by the previous call (starting at 1). */
  def newTyvar: Type = {
    val next = n + 1;
    n = next;
    Tyvar("a" + next)
  }
}
import types._;

/** A set represented as a list of elements (intended to hold no duplicates).
 *  All operations are linear scans, so union/diff cost O(n * m).
 */
case class ListSet[a](elems: List[a]) {

  /** Returns true iff some element of this set equals `y`. */
  def contains(y: a): Boolean = elems match {
    case List() => false
    case x :: xs => (x == y) || (xs contains y)
  }

  /** Returns the elements of this set not already in `ys`, followed by `ys`.
   *  An element present in both sets appears once, at its position in `ys`. */
  def union(ys: ListSet[a]): ListSet[a] = elems match {
    case List() => ys
    case x :: xs =>
      if (ys contains x) ListSet(xs) union ys
      else ListSet(x :: (ListSet(xs) union ys).elems)
  }

  /** Returns the elements of this set that are not contained in `ys`,
   *  in their original order. */
  def diff(ys: ListSet[a]): ListSet[a] = elems match {
    case List() => ListSet(List())
    case x :: xs =>
      if (ys contains x) ListSet(xs) diff ys
      else ListSet(x :: (ListSet(xs) diff ys).elems)
  }
}

module typeInfer {

  /** A substitution from type variables to types. Unbound variables are
   *  mapped to themselves by `lookup`. */
  trait Subst with Function1[Type,Type] {
    /** Returns the type bound to `x`, or `x` itself when unbound. */
    def lookup(x: Tyvar): Type;
    /** Applies this substitution to `t`, chasing variable bindings until a
     *  variable maps to itself. There is no cycle guard here: termination
     *  relies on `mgu` never creating a cyclic binding (occurs check). */
    def apply(t: Type): Type = t match {
      case Tyvar(a) => val u = lookup(Tyvar(a)); if (t == u) t else apply(u);
      case Arrow(t1, t2) => Arrow(apply(t1), apply(t2))
      case Tycon(k, ts) => Tycon(k, ts map apply)
    }
    /** Returns a substitution that maps `x` to `t` and otherwise behaves
     *  like this one. */
    def extend(x: Tyvar, t: Type) = new Subst {
      def lookup(y: Tyvar): Type = if (x == y) t else Subst.this.lookup(y);
    }
  }

  /** The identity substitution. */
  val emptySubst = new Subst { def lookup(t: Tyvar): Type = t }

  /** Names of the type variables occurring in `t`. */
  def tyvars(t: Type): ListSet[String] = t match {
    case Tyvar(a) => new ListSet(List(a))
    case Arrow(t1, t2) => tyvars(t1) union tyvars(t2)
    case Tycon(k, ts) => tyvars(ts)
  }
  /** Free type variables of a scheme: those of its body minus the quantified ones. */
  def tyvars(ts: TypeScheme): ListSet[String] = ts match {
    case TypeScheme(vs, t) => tyvars(t) diff new ListSet(vs)
  }
  /** Union of the type variables of all types in `ts`. */
  def tyvars(ts: List[Type]): ListSet[String] = ts match {
    case List() => new ListSet[String](List())
    case t :: ts1 => tyvars(t) union tyvars(ts1)
  }
  /** Union of the free type variables of all schemes in `env`. */
  def tyvars(env: Env): ListSet[String] = env match {
    case List() => new ListSet[String](List())
    case Pair(x, t) :: env1 => tyvars(t) union tyvars(env1)
  }

  /** A type `t` universally quantified over the variables named in `vs`. */
  case class TypeScheme(vs: List[String], t: Type) {
    /** Instantiates the scheme, replacing each quantified variable with a
     *  fresh type variable. */
    def newInstance: Type =
      (emptySubst foldl_: vs) { (s, a) => s.extend(Tyvar(a), newTyvar) } (t);
  }

  /** Typing environment; lookup scans from the front, so later bindings
   *  (prepended) shadow earlier ones. */
  type Env = List[Pair[String, TypeScheme]];

  /** Returns the scheme bound to `x` in `env`, or null when `x` is unbound. */
  def lookup(env: Env, x: String): TypeScheme = env match {
    case List() => null
    case Pair(y, t) :: env1 => if (x == y) t else lookup(env1, x)
  }

  /** Quantifies `t` over every type variable that is not free in `env`. */
  def gen(env: Env, t: Type): TypeScheme =
    TypeScheme((tyvars(t) diff tyvars(env)).elems, t);

  /** Returns `env` with substitution `s` applied to every scheme body. */
  private def substEnv(s: Subst, env: Env): Env =
    env map { case Pair(x, TypeScheme(vs, t)) => Pair(x, TypeScheme(vs, s(t))) };

  /** Extends `s` to a most general unifier of `t` and `u`.
   *  Calls `error` when no unifier exists. */
  def mgu(t: Type, u: Type)(s: Subst): Subst = Pair(s(t), s(u)) match {
    case Pair(Tyvar( a), Tyvar(b)) if (a == b) =>
      s
    case Pair(Tyvar(a), su) =>
      // The occurs check must look at `u` under `s`: checking the raw `u`
      // can miss `a` and create a cyclic binding that makes `apply` loop.
      if (tyvars(su) contains a) error("unification failure: occurs check")
      else s.extend(Tyvar(a), su)
    case Pair(_, Tyvar(a)) =>
      mgu(u, t)(s)
    case Pair(Arrow(t1, t2), Arrow(u1, u2)) =>
      mgu(t1, u1)(mgu(t2, u2)(s))
    case Pair(Tycon(k1, ts), Tycon(k2, us)) if (k1 == k2 && ts.length == us.length) =>
      // Unify argument lists pairwise, threading the substitution through.
      (s foldl_: ((ts zip us) map {case Pair(t,u) => mgu(t,u)})) { (s, f) => f(s) }
    case _ => error("unification failure");
  }

  /** Extends `s` so that term `e` has type `t` in environment `env`. */
  def tp(env: Env, e: Term, t: Type)(s: Subst): Subst = e match {
    case Var(x) => {
      val u = lookup(env, x);
      if (u == null) error("undefined: " + x)
      else mgu(u.newInstance, t)(s)
    }
    case Lam(x, e1) => {
      val a = newTyvar, b = newTyvar;
      val s1 = mgu(t, Arrow(a, b))(s);
      // Lambda-bound variables are monomorphic: no quantified variables.
      val env1 = Pair(x, TypeScheme(List(), a)) :: env;
      tp(env1, e1, b)(s1)
    }
    case App(e1, e2) => {
      val a = newTyvar;
      val s1 = tp(env, e1, Arrow(a, t))(s);
      tp(env, e2, a)(s1)
    }
    case Let(x, e1, e2) => {
      val a = newTyvar;
      val s1 = tp(env, e1, a)(s);
      // Generalize relative to the environment as seen through s1; using the
      // raw env would quantify variables already constrained by it.
      tp(Pair(x, gen(substEnv(s1, env), s1(a))) :: env, e2, t)(s1)
    }
  }

  /** Infers the type of `e` in environment `env`. */
  def typeOf(env: Env, e: Term): Type = {
    val a = newTyvar;
    tp(env, e, a)(emptySubst)(a)
  }
}

module predefined {
  val booleanType = Tycon("Boolean", List());
  val intType = Tycon("Int", List());
  def listType(t: Type) = Tycon("List", List(t));

  // Generalizing over an empty environment quantifies every type variable,
  // so each entry below becomes a closed scheme.
  private def closed(t: Type): typeInfer.TypeScheme = typeInfer.gen(List(), t);

  // A single type variable shared by all entries; because each scheme is
  // instantiated with fresh variables on use, the sharing is harmless.
  private val alpha = newTyvar;
  private val alphaList = listType(alpha);

  val env = List(
    Pair("true",    closed(booleanType)),
    Pair("false",   closed(booleanType)),
    Pair("if",      closed(Arrow(booleanType, Arrow(alpha, Arrow(alpha, alpha))))),
    Pair("zero",    closed(intType)),
    Pair("succ",    closed(Arrow(intType, intType))),
    Pair("nil",     closed(alphaList)),
    Pair("cons",    closed(Arrow(alpha, Arrow(alphaList, alphaList)))),
    Pair("isEmpty", closed(Arrow(alphaList, booleanType))),
    Pair("head",    closed(Arrow(alphaList, alpha))),
    Pair("tail",    closed(Arrow(alphaList, alphaList))),
    Pair("fix",     closed(Arrow(Arrow(alpha, alpha), alpha)))
  )
}

module test {

  /** Infers the type of `e` in the predefined environment. */
  def showType(e: Term): Type = typeInfer.typeOf(predefined.env, e);

  // \x. cons x nil -- should infer a type of the shape Arrow(a, List[a])
  // for some fresh variable a. Printed so the example has visible output
  // (the original discarded the result).
  System.out.println(showType(Lam("x", App(App(Var("cons"), Var("x")), Var("nil")))));

}