-- foundobjects/hindleymilner - algorithm j.lua

-- 269 lines
-- 7.3 KiB
-- Lua
-- Executable file

-- Hindley-Milner type inference using algorithm J.
-- Based on:
-- https://en.wikipedia.org/wiki/Hindley%E2%80%93Milner_type_system#Algorithm_J
-- https://github.com/jfecher/algorithm-j/blob/master/j.ml
-- Shape reference for the three kinds of monotype handled below.
-- It is not referenced by the code visible in this file; it only illustrates
-- which fields each kind of monotype table carries.
local monotypes = {}
-- base type: identified by its name
monotypes.base_type = { base = "nil" }
-- type variable: `level` is the value of current_level when it was created
monotypes.variable = { level = 0, var = true }
-- function (application) type; the global `type` function is used here
-- purely as a placeholder meaning "some monotype"
monotypes.fn = { to = type, from = type, fn = true }
-- Shape reference for polytypes (type schemes).
-- `var` and `monotype` are not defined anywhere above this point, so unless
-- they exist as globals they evaluate to nil; they are placeholders meaning
-- "a list of variable monotypes" and "a monotype".
local polytypes = {}
-- quantifier: the bound variables and the monotype they are bound in
polytypes.quantifier = { type = monotype, bound = { var } }
---- Monotypes ----
-- Level stamped on freshly created variables. infer() raises it while
-- inferring the bound expression of a `let` and lowers it again afterwards;
-- generalize() compares variable levels against it.
local current_level = 0

-- Produce a new variable monotype.
-- Variables are compared by address, so every call yields a variable that is
-- unique by construction.
-- @return a table { var = true, level = <current_level at creation> }
local function new_var()
  local fresh = { level = current_level }
  fresh.var = true
  return fresh
end
-- Build a function monotype `from -> to`.
-- @param from monotype of the argument
-- @param to   monotype of the result
-- @return a new table { fn = true, from = from, to = to }
local function make_fn(from, to)
  local arrow = { fn = true }
  arrow.from = from
  arrow.to = to
  return arrow
end
-- Build a base monotype.
-- @param str name of the base type
-- @return a new table { base = str }
local function make_base(str)
  local base_type = {}
  base_type.base = str
  return base_type
end
-- Compare two monotypes for equality.
-- * two base types are equal when their names are equal;
-- * a variable is only equal to itself (variables are compared by address);
-- * two function types are equal when both their `from` and `to` parts are.
-- Parent links are not followed here.
-- @return true or false
local function equal(monotype_a, monotype_b)
  local base_a, base_b = monotype_a.base, monotype_b.base
  if base_a and base_b and base_a == base_b then
    return true
  end
  if monotype_a.var and monotype_a == monotype_b then
    return true
  end
  if monotype_a.fn and monotype_b.fn then
    if not equal(monotype_a.from, monotype_b.from) then
      return false
    end
    return equal(monotype_a.to, monotype_b.to)
  end
  return false
end
-- Union-find (disjoint-set) structure over monotypes.
-- https://en.wikipedia.org/wiki/Disjoint-set_data_structure

-- Return the representative of the set `monotype` belongs to, i.e. the table
-- reached by following `parent` links until one has no parent.
-- No path compression is performed.
local function find(monotype)
  local root = monotype
  local above = root.parent
  while above ~= nil do
    root = above
    above = root.parent
  end
  return root
end
-- Merge the set containing monotype_a into the set containing monotype_b:
-- the representative of monotype_a gets the representative of monotype_b as
-- its parent. Nothing happens when the two representatives already compare
-- equal according to equal().
local function union(monotype_a, monotype_b)
  local root_a = find(monotype_a)
  local root_b = find(monotype_b)
  if not equal(root_a, root_b) then
    root_a.parent = root_b
  end
end
--- Get the string representation of a type.
-- Base types print as their name. Variables are named in order of first
-- appearance: a, b, ..., z, then a1, b1, ..., z1, a2, ... (the original
-- implementation produced non-letter characters after the 26th variable).
-- Function types print as `from -> to`, with the argument parenthesised when
-- it is itself a function type.
-- @param t     the monotype to print; its representative is looked up first
-- @param state optional naming state { i = <next index>, map = <var -> name> };
--              pass the same table to several calls to share variable names
-- @return the string, or nil when `t` is none of base/var/fn
local function type_to_string(t, state)
  state = state or { i = 0, map = {} }
  t = find(t)
  if t.base then
    return tostring(t.base)
  elseif t.var then
    local name = state.map[t]
    if not name then
      local index = state.i
      local letter = string.char(97 + index % 26)
      local round = math.floor(index / 26)
      if round == 0 then
        name = letter
      else
        name = ("%s%d"):format(letter, round)
      end
      state.map[t] = name
      state.i = index + 1
    end
    return name
  elseif t.fn then
    local from = find(t.from)
    -- `->` is right-associative, so only a function-typed argument needs
    -- parentheses.
    local pattern = (from.var or from.base) and "%s -> %s" or "(%s) -> %s"
    -- The argument is rendered first so that variables are named left to right.
    return pattern:format(type_to_string(from, state), type_to_string(t.to, state))
  end
end
-- Occurs check: does the variable `vartype` appear inside `monotype`?
--
-- Every sub-term is first resolved to its representative with find(), so
-- occurrences hidden behind already-unified variables are detected too.
--
-- As in the reference Algorithm J implementation, the walk also lowers the
-- level of every variable met in `monotype` to the level of `vartype`: once
-- the two are unified those variables are reachable from the scope that owns
-- `vartype` and must not be generalized by a deeper `let`.
-- @param vartype  the (representative) variable being bound
-- @param monotype the type it is about to be bound to
-- @return true when `vartype` occurs in `monotype`, false otherwise
local function occurs(vartype, monotype)
  monotype = find(monotype)
  if monotype.var then
    if monotype == vartype then
      return true
    end
    -- Guarded so that a caller passing a non-variable as `vartype` (no
    -- `level` field) gets a plain "does not occur" answer instead of an error.
    local level = vartype.level
    if level and monotype.level and monotype.level > level then
      monotype.level = level
    end
    return false
  elseif monotype.fn then
    return occurs(vartype, monotype.from) or occurs(vartype, monotype.to)
  end
  return false
end
-- Unification.
-- Make two monotypes equal by merging their sets, recursing through function
-- types. Raises an error when the types cannot be unified:
--   * "recursive binding" when a variable would have to contain itself;
--   * "can't unify type X and Y" when the two types have different shapes or
--     different base names.
-- @param monotype_a first monotype
-- @param monotype_b second monotype
local function unify(monotype_a, monotype_b)
  -- Get monotype representatives.
  monotype_a = find(monotype_a)
  monotype_b = find(monotype_b)
  if monotype_a == monotype_b then
    -- Same representative: already unified. This must be tested before the
    -- occurs check, which would otherwise report a variable as occurring in
    -- itself.
    return
  elseif monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
    return
  elseif monotype_a.fn and monotype_b.fn then
    unify(monotype_a.from, monotype_b.from)
    unify(monotype_a.to, monotype_b.to)
  elseif monotype_a.var then
    assert(not occurs(monotype_a, monotype_b), "recursive binding")
    union(monotype_a, monotype_b)
  elseif monotype_b.var then
    -- Here monotype_b is the variable being bound, so it is the one that must
    -- not occur inside monotype_a.
    assert(not occurs(monotype_b, monotype_a), "recursive binding")
    union(monotype_b, monotype_a)
  else
    -- Share one naming state so a variable common to both types gets the
    -- same letter in both halves of the message.
    local names = { i = 0, map = {} }
    error(("can't unify type %s and %s"):format(
      tostring(type_to_string(monotype_a, names)),
      tostring(type_to_string(monotype_b, names))
    ))
  end
end
---- Polytypes ----
-- Specialize the polytype by copying the term and replacing the bound type
-- variables consistently by new monotype variables.
-- Every sub-term is resolved to its representative with find() before being
-- copied: a variable that has since been unified with something else must be
-- instantiated as that something, not as if it were still free.
-- @param polytype table { bound = <list of variables>, type = <monotype> }
-- @return a monotype in which each bound variable is replaced by a fresh one;
--         unbound variables and base types are shared with the original
local function inst(polytype)
  local map = {}
  for _, var in ipairs(polytype.bound) do
    map[var] = new_var()
  end
  -- copy/replace in the term
  local function inst_rec(t)
    t = find(t)
    if t.fn then
      return make_fn(inst_rec(t.from), inst_rec(t.to))
    elseif t.var then
      return map[t] or t
    end
    return t
  end
  return inst_rec(polytype.type)
end
-- Create a polytype from a monotype, quantifying the type variables that
-- appear in it and whose level is greater than current_level (i.e. variables
-- created inside the `let` being closed).
-- Sub-terms are resolved to their representative with find() first, so that
-- only variables that are still free are quantified and variables hidden
-- behind an already-unified variable are not missed.
-- @param monotype the monotype to generalize
-- @return table { bound = <list of quantified variables>, type = monotype }
local function generalize(monotype)
  local found = {}
  local l = {}
  local function list_var_rec(t)
    t = find(t)
    if t.fn then
      list_var_rec(t.from)
      list_var_rec(t.to)
    elseif t.var and not found[t] then
      found[t] = true
      if t.level > current_level then
        l[#l + 1] = t
      end
    end
  end
  list_var_rec(monotype)
  return { bound = l, type = monotype }
end
-- Wrap a monotype in a polytype without quantifying anything: the list of
-- bound variables is empty and the monotype is stored as is.
-- @param monotype the monotype to wrap
-- @return table { bound = {}, type = monotype }
local function dont_generalize(monotype)
  local polytype = { type = monotype }
  polytype.bound = {}
  return polytype
end
---- Inference ----
--- Infer the type of an expression.
-- Expressions are sequences whose first element names the construct:
--   { "base", name }               a value of base type `name`
--   { "id", name }                 a reference to a bound identifier
--   { "call", fn_expr, arg_expr }  function application
--   { "lambda", name, body }       abstraction
--   { "let", name, value, body }   let-binding with generalization
-- @param expr the expression to type
-- @param env  optional environment mapping identifier names to polytypes
-- @return the inferred monotype
-- Raises an error on unbound identifiers, on unknown expression kinds and
-- when unification fails.
local function infer(expr, env)
  env = env or {}
  local kind = expr[1]
  if kind == "base" then
    return make_base(expr[2])
  -- Var rule
  elseif kind == "id" then
    local s = env[expr[2]]
    if not s then
      -- Level 0 keeps the message free of position information, exactly as
      -- the assert() previously used here did; tostring() keeps formatting
      -- safe when the name is missing or not a string.
      error(("unbound identifier %s"):format(tostring(expr[2])), 0)
    end
    return inst(s)
  -- App rule
  elseif kind == "call" then
    local t0 = infer(expr[2], env)
    local t1 = infer(expr[3], env)
    local tt = new_var()
    unify(t0, make_fn(t1, tt))
    return tt
  -- Abs rule
  elseif kind == "lambda" then
    local t = new_var()
    -- The parameter is bound without generalization; outer bindings stay
    -- visible through __index.
    local envb = setmetatable({ [expr[2]] = dont_generalize(t) }, { __index = env })
    local tt = infer(expr[3], envb)
    return make_fn(t, tt)
  -- Let rule
  elseif kind == "let" then
    -- Variables created while typing the bound expression get a deeper
    -- level, which is what generalize() uses to pick the ones to quantify.
    current_level = current_level + 1
    local t = infer(expr[3], env)
    current_level = current_level - 1
    local envb = setmetatable({ [expr[2]] = generalize(t) }, { __index = env })
    local tt = infer(expr[4], envb)
    return tt
  else
    -- Dump the offending expression when the optional `inspect` module is
    -- installed; its absence must not hide the actual error below.
    local ok, inspect = pcall(require, "inspect")
    if ok then
      print(inspect(expr))
    end
    error(("unknown expression %s"):format(tostring(kind)))
  end
end
---- Test ----
-- Parse a parenthesised s-expression into nested Lua sequences.
-- Inside a form, items are either nested forms or alphanumeric words
-- (pattern %w+), separated by whitespace.
-- Leading whitespace before the opening parenthesis is skipped. Anything that
-- follows the first balanced form is ignored, as before.
-- @param s the source string
-- @return a sequence of words (strings) and nested sequences, or nil when
--         `s` is empty or only whitespace
-- Raises "unexpected ..." on an item that is neither a form nor a word, and
-- "expected EOF near ..." when `s` does not start with a balanced form.
local function parse(s)
  local form = s:match("^%s*(%b())")
  if form then
    -- Strip the outer parentheses and the whitespace right after the opening one.
    local body = form:match("^%(%s*(.*)%s*%)$")
    local l = {}
    local i = 1
    while body:match("[^%s]", i) do
      -- Each branch consumes one item plus trailing whitespace; the empty
      -- capture () yields the position where the next item starts.
      local item, next_i = body:match("^(%b())%s*()", i)
      if item then
        l[#l + 1] = parse(item)
      else
        item, next_i = body:match("^(%w+)%s*()", i)
        if not item then
          error(("unexpected %q"):format(body:sub(i)))
        end
        l[#l + 1] = item
      end
      i = next_i
    end
    return l
  elseif s:match("[^%s]") then
    error(("expected EOF near %q"):format(s))
  end
end
-- Test cases: `exp` is the source handed to parse(), `result` the expected
-- output of type_to_string() on the inferred type.
local tests = {
  -- \f. \x. f x
  { exp = "(lambda f (lambda x (call (id f) (id x))))",
    result = "(a -> b) -> a -> b" },
  -- \f. \x. f (f x)
  { exp = "(lambda f (lambda x (call (id f) (call (id f) (id x)))))",
    result = "(a -> a) -> a -> a" },
  -- \m. \n. \f. \x. m f (n f x)
  { exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (id f)) (call (call (id n) (id f)) (id x)))))))))",
    result = "(a -> b -> c) -> (a -> d -> b) -> a -> d -> c" },
  -- \n. \f. \x. f (n f x)
  { exp = "(lambda n (lambda f (lambda x (call (id f) (call (call (id n) (id f)) (id x))))))",
    result = "((a -> b) -> c -> a) -> (a -> b) -> c -> b" },
  -- \m. \n. \f. \x. m (n f) x
  { exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (call (id n) (id f))) (id x)))))))))",
    result = "(a -> b -> c) -> (d -> a) -> d -> b -> c" },
  -- \n. \f. \x. n (\g. \h. h (g f)) (\u. x) (\u. u)
  { exp = "(lambda n (lambda f (lambda x (call (call (call (id n) (lambda g (lambda h (call (id h) (call (id g) (id f)))))) (lambda u (id x))) (lambda u (id u))))))",
    result = "(((a -> b) -> (b -> c) -> c) -> (d -> e) -> (f -> f) -> g) -> a -> e -> g" },
  -- let bound to a lambda parameter: nothing to generalize
  { exp = "(lambda x (let y (id x) (id y)))",
    result = "a -> a" },
  -- let bound to a lambda that captures the outer parameter
  { exp = "(lambda x (let y (lambda z (id x)) (id y)))",
    result = "a -> b -> a" },
}
-- Run every test case: parse the source, infer its type, render it and
-- report any mismatch with the expected string. Nothing is printed for
-- passing cases.
for index = 1, #tests do
  local case = tests[index]
  local expression = parse(case.exp)
  local actual = type_to_string(infer(expression))
  if actual ~= case.result then
    print("invalid result for test", case.exp, "expected", case.result, "but received", actual)
  end
end