-- Hindley-Milner type inference using algorithm J.
-- Based on:
-- https://en.wikipedia.org/wiki/Hindley%E2%80%93Milner_type_system#Algorithm_J
-- https://github.com/jfecher/algorithm-j/blob/master/j.ml

-- Shapes of the type tables used throughout this file (documentation only).
--
-- Monotypes:
--   base type:            { base = "nil" }
--   variable:             { var = true, level = 0 }
--   application/function: { fn = true, from = <monotype>, to = <monotype> }
--
-- Polytypes:
--   quantifier:           { bound = { <variable>, ... }, type = <monotype> }
--
-- Any monotype may additionally receive a `parent` field when union() merges
-- it into another monotype; find() follows these links to the representative.

---- Monotypes ----

-- Level stamped on every variable created by new_var().
local current_level = 0

-- Produce a new variable monotype.
-- Variables are compared by address, so every call returns a variable that is
-- distinct from all others by construction.
-- @return a fresh variable carrying the current level
local function new_var()
	return { var = true, level = current_level }
end

-- Returns a function monotype.
-- @param from  monotype of the argument
-- @param to    monotype of the result
local function make_fn(from, to)
	return { fn = true, from = from, to = to }
end

-- Returns a base monotype.
-- @param str  name of the base type
local function make_base(str)
	return { base = str }
end

-- Union-set algorithm.
-- https://en.wikipedia.org/wiki/Disjoint-set_data_structure

-- Return the representative of the set a monotype belongs to.
-- Every monotype visited on the way is re-pointed directly at the
-- representative (path compression), which keeps later lookups short without
-- changing the result.
-- @param monotype  any monotype
-- @return the monotype at the end of the `parent` chain
local function find(monotype)
	local root = monotype
	while root.parent ~= nil do
		root = root.parent
	end
	while monotype.parent ~= nil do
		local next_type = monotype.parent
		monotype.parent = root
		monotype = next_type
	end
	return root
end

-- Compare two monotypes for equality.
-- Both sides are resolved through find() first, at every nesting depth, so
-- types that have been merged by union() compare equal.
-- @return true when both monotypes denote the same type, false otherwise
local function equal(monotype_a, monotype_b)
	monotype_a, monotype_b = find(monotype_a), find(monotype_b)
	if monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
		return true
	elseif monotype_a.var and monotype_a == monotype_b then -- variables are compared by address
		return true
	elseif monotype_a.fn and monotype_b.fn then
		return equal(monotype_a.from, monotype_b.from) and equal(monotype_a.to, monotype_b.to)
	end
	return false
end

-- Merge the set of monotype_a into the set of monotype_b: the representative
-- of monotype_a gets the representative of monotype_b as its parent.
-- Nothing happens when both representatives are already equal.
local function union(monotype_a, monotype_b)
	monotype_a, monotype_b = find(monotype_a), find(monotype_b)
	if equal(monotype_a, monotype_b) then
		return -- already in the same set
	end
	monotype_a.parent = monotype_b
end

--- Get string representation of a type.
-- Type variables are named in order of first appearance: a..z for the first
-- 26, then a1..z1, a2..z2, and so on. Names are kept in `state`, so the same
-- variable prints identically everywhere inside one type.
-- @param t      monotype to print; it is resolved through find() first
-- @param state  optional naming state { i = next index, map = variable -> name }
-- @return the string form, or nil when t is not a base, variable or function
local function type_to_string(t, state)
	state = state or { i = 0, map = {} }
	t = find(t)
	if t.base then
		return tostring(t.base)
	elseif t.var then
		local name = state.map[t]
		if not name then
			local letter = string.char(97 + state.i % 26)
			local round = math.floor(state.i / 26)
			if round == 0 then
				name = letter
			else
				name = letter .. round
			end
			state.map[t] = name
			state.i = state.i + 1
		end
		return name
	elseif t.fn then
		local from = find(t.from)
		if from.var or from.base then
			return ("%s -> %s"):format(type_to_string(from, state), type_to_string(t.to, state))
		else
			-- a function on the left of an arrow needs parentheses
			return ("(%s) -> %s"):format(type_to_string(from, state), type_to_string(t.to, state))
		end
	end
end

-- Check if vartype appears in monotype.
-- Every nested type is resolved through find() before it is inspected, so a
-- variable that was already merged into vartype is detected too.
-- As a side effect, each variable met during the walk has its level lowered to
-- vartype's level when it is deeper: once vartype is bound to monotype, those
-- variables are reachable from vartype's scope and must not look local to a
-- deeper `let` when levels are compared later.
-- @param vartype   an unbound variable monotype
-- @param monotype  the monotype to search
-- @return true when vartype occurs in monotype, false otherwise
local function occurs(vartype, monotype)
	monotype = find(monotype)
	if monotype.base then
		return false
	elseif monotype.var then
		local level = vartype.level
		if level and monotype.level and monotype.level > level then
			monotype.level = level
		end
		return monotype == vartype
	elseif monotype.fn then
		return occurs(vartype, monotype.from) or occurs(vartype, monotype.to)
	end
	return false
end

-- Unification: make monotype_a and monotype_b denote the same type by binding
-- variables through union().
-- Raises "recursive binding" when a variable would have to contain itself, and
-- a "can't unify" error when the two types have incompatible shapes.
local function unify(monotype_a, monotype_b)
	-- Get monotype representatives.
	monotype_a = find(monotype_a)
	monotype_b = find(monotype_b)
	-- Same representative: already unified. Without this check a variable
	-- unified with itself would trip the occurs check below.
	if monotype_a == monotype_b then
		return
	end
	if monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
		return
	elseif monotype_a.fn and monotype_b.fn then
		unify(monotype_a.from, monotype_b.from)
		unify(monotype_a.to, monotype_b.to)
	elseif monotype_a.var then
		assert(not occurs(monotype_a, monotype_b), "recursive binding")
		union(monotype_a, monotype_b)
	elseif monotype_b.var then
		-- the variable is on the right here, so it is the one to look for
		assert(not occurs(monotype_b, monotype_a), "recursive binding")
		union(monotype_b, monotype_a)
	else
		-- one naming state for both sides so shared variables get the same name
		local names = { i = 0, map = {} }
		error(("can't unify type %s and %s"):format(
			tostring(type_to_string(monotype_a, names)),
			tostring(type_to_string(monotype_b, names))
		))
	end
end

---- Polytypes ----

-- Specialize the polytype by copying the term and replacing the bound type
-- variables consistently by new monotype variables.
-- @param polytype  { bound = { variables... }, type = monotype }
-- @return a monotype in which every bound variable is replaced by a fresh one
local function inst(polytype)
	local map = {}
	for _, var in ipairs(polytype.bound) do
		map[var] = new_var()
	end
	-- copy/replace in the term
	local function inst_rec(t)
		-- Resolve first: a variable that has since been bound must yield the
		-- type it is bound to, not itself or a fresh replacement.
		t = find(t)
		if t.base then
			return t
		elseif t.fn then
			return make_fn(inst_rec(t.from), inst_rec(t.to))
		elseif t.var then
			return map[t] or t
		end
	end
	return inst_rec(polytype.type)
end

-- Create a polytype from a
-- monotype, quantifying the type variables that appear in the monotype.
-- Only variables whose level is deeper than current_level are quantified;
-- the others are left as they are.
-- @param monotype  the type to generalize
-- @return polytype { bound = { quantified variables... }, type = monotype }
local function generalize(monotype)
	local found = {}
	local l = {}
	local function list_var_rec(t)
		-- Resolve through the union-find links first: a variable that has
		-- already been bound stands for another type and must not be
		-- quantified itself.
		t = find(t)
		if t.fn then
			list_var_rec(t.from)
			list_var_rec(t.to)
		elseif t.var and not found[t] then
			if t.level > current_level then
				table.insert(l, t)
			end
			found[t] = true
		end
	end
	list_var_rec(monotype)
	return { bound = l, type = monotype }
end

-- Create a polytype from a monotype, as is (nothing is quantified).
local function dont_generalize(monotype)
	return { bound = {}, type = monotype }
end

---- Inference ----

--- Infer types from expression!
-- Expressions are arrays whose first element names the rule:
--   { "base", name }               base type literal
--   { "id", name }                 identifier lookup (Var rule)
--   { "call", fn, arg }            application (App rule)
--   { "lambda", param, body }      abstraction (Abs rule)
--   { "let", name, value, body }   let binding (Let rule)
-- @param expr  expression table
-- @param env   optional table mapping identifier names to polytypes
-- @return the inferred monotype
-- Raises an error on unbound identifiers, unknown expressions and
-- unification failures.
local function infer(expr, env)
	env = env or {}
	if type(expr) ~= "table" then
		error(("unknown expression %s"):format(tostring(expr)))
	end
	local kind = expr[1]
	if kind == "base" then
		return make_base(expr[2])
	-- Var rule
	elseif kind == "id" then
		local s = env[expr[2]]
		if not s then
			error(("unbound identifier %s"):format(tostring(expr[2])), 0)
		end
		return inst(s)
	-- App rule
	elseif kind == "call" then
		local t0 = infer(expr[2], env)
		local t1 = infer(expr[3], env)
		local tt = new_var()
		unify(t0, make_fn(t1, tt))
		return tt
	-- Abs rule
	elseif kind == "lambda" then
		local t = new_var()
		-- the new scope falls back to the enclosing environment for other names
		local envb = setmetatable({ [expr[2]] = dont_generalize(t) }, { __index = env })
		local tt = infer(expr[3], envb)
		return make_fn(t, tt)
	-- Let rule
	elseif kind == "let" then
		current_level = current_level + 1
		-- Protected call so the level is restored even when inference of the
		-- bound expression fails; the error is then re-raised unchanged.
		local ok, t = pcall(infer, expr[3], env)
		current_level = current_level - 1
		if not ok then
			error(t, 0)
		end
		local envb = setmetatable({ [expr[2]] = generalize(t) }, { __index = env })
		return infer(expr[4], envb)
	else
		-- Dump the expression when the optional "inspect" module is available;
		-- its absence must not hide the actual error.
		local has_inspect, inspect = pcall(require, "inspect")
		if has_inspect then
			print(inspect(expr))
		end
		error(("unknown expression %s"):format(tostring(kind)))
	end
end

---- Test ----

-- Parse a parenthesised expression into nested arrays of strings.
-- The input must start with a balanced "(...)" group; its items are either
-- nested groups (parsed recursively) or alphanumeric words. Anything after the
-- first balanced group is ignored.
-- @param s  source string
-- @return nested array, or nil when s is empty or only whitespace
-- Raises an error on unexpected characters or when s does not start with a
-- balanced group.
local function parse(s)
	local group = s:match("^%b()")
	if group then
		-- strip the outer parentheses and the leading whitespace
		s = group:match("^%(%s*(.*)%s*%)$")
		local l = {}
		local i = 1
		while s:match("[^%s]", i) do
			if s:match("^%b()", i) then
				local ss
				ss, i = s:match("^(%b())%s*()", i)
				table.insert(l, parse(ss))
			elseif s:match("^%w+", i) then
				local ss
				ss, i = s:match("^(%w+)%s*()", i)
				table.insert(l, ss)
			else
				error(("unexpected %q"):format(s:sub(i)))
			end
		end
		return l
	elseif s:match("[^%s]") then
		error(("expected EOF near %q"):format(s))
	end
end

local tests = {
	{
		exp = "(lambda f (lambda x (call (id f) (id x))))",
		result = "(a -> b) -> a -> b"
	},
	{
		exp = "(lambda f (lambda x (call (id f) (call (id f) (id x)))))",
		result = "(a -> a) -> a -> a"
	},
	{
		exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (id f)) (call (call (id n) (id f)) (id x)))))))))",
		result = "(a -> b -> c) -> (a -> d -> b) -> a -> d -> c"
	},
	{
		exp = "(lambda n (lambda f (lambda x (call (id f) (call (call (id n) (id f)) (id x))))))",
		result = "((a -> b) -> c -> a) -> (a -> b) -> c -> b"
	},
	{
		exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (call (id n) (id f))) (id x)))))))))",
		result = "(a -> b -> c) -> (d -> a) -> d -> b -> c"
	},
	{
		exp = "(lambda n (lambda f (lambda x (call (call (call (id n) (lambda g (lambda h (call (id h) (call (id g) (id f)))))) (lambda u (id x))) (lambda u (id u))))))",
		result = "(((a -> b) -> (b -> c) -> c) -> (d -> e) -> (f -> f) -> g) -> a -> e -> g"
	},
	{
		exp = "(lambda x (let y (id x) (id y)))",
		result = "a -> a"
	},
	{
		exp = "(lambda x (let y (lambda z (id x)) (id y)))",
		result = "a -> b -> a"
	},
	-- a let-bound value whose type is a variable already bound to a base type
	-- must keep that base type
	{
		exp = "(let y (call (lambda u (id u)) (base int)) (id y))",
		result = "int"
	}
}

for _, t in ipairs(tests) do
	local r = type_to_string(infer(parse(t.exp)))
	if r ~= t.result then
		print("invalid result for test", t.exp, "expected", t.result, "but received", r)
	end
end