--[==[
Patch metadata and README hunk carried at the top of this file (kept verbatim, not Lua code):

diff --git a/README.md b/README.md
index 6323759..651dbdf 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,7 @@
 # foundobjects
-various bits and pieces from around the place
\ No newline at end of file
+various bits and pieces from around the place
+
+this repo contains various one-file implementations of various algorithms or simple programs, stuff i've done for learning or experimentation purposes
+
+nothing here is new or of particular value, might be helpful for educational purposes or something though so here you go
diff --git a/hindleymilner - algorithm j.lua b/hindleymilner - algorithm j.lua
new file mode 100755
index 0000000..347eb02
--- /dev/null
+++ b/hindleymilner - algorithm j.lua
@@ -0,0 +1,269 @@
]==]

-- Hindley-Milner type inference using algorithm J.
-- Based on:
-- https://en.wikipedia.org/wiki/Hindley%E2%80%93Milner_type_system#Algorithm_J
-- https://github.com/jfecher/algorithm-j/blob/master/j.ml

-- Shape reference for the type representations built below.
-- These two tables are documentation only: the placeholders are plain strings so that
-- loading the file never reads undefined globals (or the builtin `type` function).
local monotypes = {
	base_type = { base = "nil" },
	-- variables
	variable = { var = true, level = 0 },
	-- application/function
	fn = { fn = true, from = "monotype", to = "monotype" }
}
local polytypes = {
	-- quantifier
	quantifier = { bound = { "variable" }, type = "monotype" }
}

---- Monotypes ----

-- Level stamped onto every variable created by new_var().
local current_level = 0

-- Produce a new variable monotype.
-- Variables are compared by address, so each returned table is a unique variable by construction.
local function new_var()
	return { var = true, level = current_level }
end
-- Returns a function monotype going from `from` to `to`.
local function make_fn(from, to)
	return { fn = true, from = from, to = to }
end
-- Returns a base monotype named `str`.
local function make_base(str)
	return { base = str }
end

-- Compare two monotypes for equality.
-- Neither argument is resolved through find() here; union() passes set representatives.
-- Returns true for two base types with the same name, for the very same variable, or for
-- two function types whose `from` and `to` are pairwise equal; false otherwise.
local function equal(monotype_a, monotype_b)
	if monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
		return true
	elseif monotype_a.var and monotype_a == monotype_b then -- variables are compared by address
		return true
	elseif monotype_a.fn and monotype_b.fn then
		return equal(monotype_a.from, monotype_b.from) and equal(monotype_a.to, monotype_b.to)
	end
	return false
end

-- Union-set algorithm.
-- https://en.wikipedia.org/wiki/Disjoint-set_data_structure

-- Follow `parent` links and return the representative of the set `monotype` belongs to.
local function find(monotype)
	while monotype.parent ~= nil do
		monotype = monotype.parent
	end
	return monotype
end
-- Merge the set of monotype_a into the set of monotype_b
-- (the representative of monotype_b becomes the representative of both).
local function union(monotype_a, monotype_b)
	monotype_a, monotype_b = find(monotype_a), find(monotype_b)
	if equal(monotype_a, monotype_b) then
		return -- already in the same set
	end
	monotype_a.parent = monotype_b
end

--- Get string representation of a type.
-- `state` is optional; it carries the variable -> name mapping so that several calls can
-- share consistent names. Variables are named a, b, c, ... in order of first appearance
-- (after z: a1, b1, ...). Returns nil for a table that is none of base/var/fn.
local function type_to_string(t, state)
	state = state or { i = 0, map = {} }
	t = find(t)
	if t.base then
		return tostring(t.base)
	elseif t.var then
		if not state.map[t] then
			local n = state.i
			local name = string.char(97 + n % 26)
			if n >= 26 then
				-- ran out of single letters: keep names readable and unique
				name = name .. math.floor(n / 26)
			end
			state.map[t] = name
			state.i = n + 1
		end
		return ("%s"):format(state.map[t])
	elseif t.fn then
		local from = find(t.from)
		-- "->" is printed right-associative, so only a function on the left needs parentheses
		if from.var or from.base then
			return ("%s -> %s"):format(type_to_string(from, state), type_to_string(t.to, state))
		else
			return ("(%s) -> %s"):format(type_to_string(from, state), type_to_string(t.to, state))
		end
	end
end

-- Check if vartype appears in monotype.
-- vartype must be an unbound variable (a representative). monotype is resolved through
-- find() at every step, so occurrences hidden behind already-bound variables are seen.
local function occurs(vartype, monotype)
	monotype = find(monotype)
	if monotype.var then
		return monotype == vartype
	elseif monotype.fn then
		return occurs(vartype, monotype.from) or occurs(vartype, monotype.to)
	end
	return false
end

-- Lower the level of every unbound variable reachable from monotype to at most `level`.
-- Needed when a variable of an outer level gets bound to this type: the variables inside
-- are then reachable from the outer scope and must not be quantified by generalize().
local function lower_levels(level, monotype)
	monotype = find(monotype)
	if monotype.var then
		if monotype.level > level then
			monotype.level = level
		end
	elseif monotype.fn then
		lower_levels(level, monotype.from)
		lower_levels(level, monotype.to)
	end
end

-- Bind the unbound variable `vartype` to `monotype` (both are representatives).
-- Raises "recursive binding" when vartype occurs inside monotype.
local function bind(vartype, monotype)
	assert(not occurs(vartype, monotype), "recursive binding")
	lower_levels(vartype.level, monotype)
	union(vartype, monotype)
end

-- Unification.
-- Makes the two monotypes equal by binding variables; returns nothing.
-- Raises an error when the types can't be unified or when a binding would be recursive.
local function unify(monotype_a, monotype_b)
	-- Get monotype representatives.
	monotype_a = find(monotype_a)
	monotype_b = find(monotype_b)
	if monotype_a == monotype_b then
		return -- same representative (e.g. a variable against itself): nothing to do
	elseif monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
		return
	elseif monotype_a.fn and monotype_b.fn then
		unify(monotype_a.from, monotype_b.from)
		unify(monotype_a.to, monotype_b.to)
	elseif monotype_a.var then
		bind(monotype_a, monotype_b)
	elseif monotype_b.var then
		bind(monotype_b, monotype_a)
	else
		-- one naming state for both sides so shared variables print with the same letter
		local names = { i = 0, map = {} }
		error(("can't unify type %s and %s"):format(
			tostring(type_to_string(monotype_a, names)),
			tostring(type_to_string(monotype_b, names))
		))
	end
end

---- Polytypes ----

-- Specialize the polytype by copying the term and replacing the bound type variables
-- consistently by new monotype variables.
local function inst(polytype)
	local map = {}
	for _, var in ipairs(polytype.bound) do
		map[var] = new_var()
	end
	-- copy/replace in the term
	local function inst_rec(t)
		t = find(t) -- look through bound variables, otherwise their structure would be lost
		if t.base then
			return t
		elseif t.fn then
			return make_fn(inst_rec(t.from), inst_rec(t.to))
		elseif t.var then
			return map[t] or t
		end
	end
	return inst_rec(polytype.type)
end

-- Create a polytype from a monotype, quantifying every unbound variable that appears in
-- the monotype and whose level is greater than current_level.
local function generalize(monotype)
	local found = {}
	local l = {}
	local function list_var_rec(t)
		t = find(t) -- only representatives can be quantified
		if t.fn then
			list_var_rec(t.from)
			list_var_rec(t.to)
		elseif t.var and not found[t] then
			if t.level > current_level then
				table.insert(l, t)
			end
			found[t] = true
		end
	end
	list_var_rec(monotype)
	return { bound = l, type = monotype }
end

-- Create a polytype from a monotype, as is (nothing is quantified).
local function dont_generalize(monotype)
	return { bound = {}, type = monotype }
end

---- Inference ----

--- Infer types from expression!
-- expr is a parsed s-expression (a list whose first item names the form):
--   { "base", name }               a value of base type `name`
--   { "id", name }                 a reference to an identifier
--   { "call", fn, arg }            function application
--   { "lambda", param, body }      abstraction
--   { "let", name, value, body }   let-binding
-- env maps identifier names to polytypes (defaults to an empty environment).
-- Returns the inferred monotype; raises on unbound identifiers, on unknown forms and on
-- any error raised by unify().
local function infer(expr, env)
	env = env or {}
	if expr[1] == "base" then
		return make_base(expr[2])
	-- Var rule
	elseif expr[1] == "id" then
		local s = assert(env[expr[2]], ("unbound identifier %s"):format(tostring(expr[2])))
		local t = inst(s)
		return t
	-- App rule
	elseif expr[1] == "call" then
		local t0 = infer(expr[2], env)
		local t1 = infer(expr[3], env)
		local tt = new_var()
		unify(t0, make_fn(t1, tt))
		return tt
	-- Abs rule
	elseif expr[1] == "lambda" then
		local t = new_var()
		-- child environment: the parameter shadows, everything else is looked up in env
		local envb = setmetatable({ [expr[2]] = dont_generalize(t) }, { __index = env })
		local tt = infer(expr[3], envb)
		return make_fn(t, tt)
	-- Let rule
	elseif expr[1] == "let" then
		-- the bound value is inferred one level deeper so generalize() can tell its variables apart
		current_level = current_level + 1
		local t = infer(expr[3], env)
		current_level = current_level - 1
		local envb = setmetatable({ [expr[2]] = generalize(t) }, { __index = env })
		local tt = infer(expr[4], envb)
		return tt
	else
		-- the dump is a debugging aid only: a missing `inspect` module must not hide the real error
		local has_inspect, inspect = pcall(require, "inspect")
		if has_inspect then
			print(inspect(expr))
		end
		error(("unknown expression %s"):format(tostring(expr[1])))
	end
end

---- Test ----

-- Parse a parenthesised expression into nested lists of strings.
-- Returns nil for an empty or whitespace-only string; raises when the input does not start
-- with a balanced (...) group, when anything but whitespace follows that group, or when
-- an item is neither a nested group nor an alphanumeric word.
local function parse(s)
	local group, rest = s:match("^(%b())(.*)$")
	if not group then
		if s:match("[^%s]") then
			error(("expected EOF near %q"):format(s))
		end
		return nil
	end
	if rest:match("[^%s]") then
		-- e.g. surplus closing parentheses: refuse instead of silently ignoring them
		error(("expected EOF near %q"):format(rest))
	end
	s = group:match("^%(%s*(.*)%s*%)$")
	local l = {}
	local i = 1
	while s:match("[^%s]", i) do
		if s:match("^%b()", i) then
			local ss
			ss, i = s:match("^(%b())%s*()", i)
			table.insert(l, parse(ss))
		elseif s:match("^%w+", i) then
			local ss
			ss, i = s:match("^(%w+)%s*()", i)
			table.insert(l, ss)
		else
			error(("unexpected %q"):format(s:sub(i)))
		end
	end
	return l
end

-- Each test: an expression and the expected string form of its inferred type.
local tests = {
	{
		exp = "(lambda f (lambda x (call (id f) (id x))))",
		result = "(a -> b) -> a -> b"
	},
	{
		exp = "(lambda f (lambda x (call (id f) (call (id f) (id x)))))",
		result = "(a -> a) -> a -> a"
	},
	{
		exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (id f)) (call (call (id n) (id f)) (id x)))))))",
		result = "(a -> b -> c) -> (a -> d -> b) -> a -> d -> c"
	},
	{
		exp = "(lambda n (lambda f (lambda x (call (id f) (call (call (id n) (id f)) (id x))))))",
		result = "((a -> b) -> c -> a) -> (a -> b) -> c -> b"
	},
	{
		exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (call (id n) (id f))) (id x))))))",
		result = "(a -> b -> c) -> (d -> a) -> d -> b -> c"
	},
	{
		exp = "(lambda n (lambda f (lambda x (call (call (call (id n) (lambda g (lambda h (call (id h) (call (id g) (id f)))))) (lambda u (id x))) (lambda u (id u))))))",
		result = "(((a -> b) -> (b -> c) -> c) -> (d -> e) -> (f -> f) -> g) -> a -> e -> g"
	},

	{
		exp = "(lambda x (let y (id x) (id y)))",
		result = "a -> a"
	},
	{
		exp = "(lambda x (let y (lambda z (id x)) (id y)))",
		result = "a -> b -> a"
	}
}

-- Run every test and report the ones whose inferred type differs from the expected one.
for _, t in ipairs(tests) do
	local r = type_to_string(infer(parse(t.exp)))
	if r ~= t.result then
		print("invalid result for test", t.exp, "expected", t.result, "but received", r)
	end
end

--[==[
Patch metadata for the second file carried by this patch (kept verbatim, not Lua code).
Everything below belongs to sexpr.lua; its `parse` is unrelated to the `parse` above.

diff --git a/sexpr.lua b/sexpr.lua
new file mode 100644
index 0000000..4e7bd6f
--- /dev/null
+++ b/sexpr.lua
@@ -0,0 +1,70 @@
]==]

--- simple s expressions parser

local parse, parse_exp, parse_atom

-- expression: starts with ( and ends with ), contains a whitespace separated list of expressions and tokens
-- s has no leading whitespace, starts with (
-- returns exp, r (r has no leading whitespace)
-- returns nil, err
parse_exp = function(s)
	local r = s:match("^%(%s*(.*)$")
	if not r then return nil, "no expression found" end
	local exp = {}
	-- collect items until one fails to parse (normally because the closing ")" was reached)
	repeat
		local item, r_item = parse(r)
		if item then
			table.insert(exp, item)
			r = r_item
		end
	until not item
	if not r:match("^%)") then
		return nil, "expected closing )"
	end
	return exp, r:match("^%)%s*(.-)$")
end

-- atom: literal delimited by whitespace, ), or (; and with escaping using \
-- s has no leading whitespace
-- returns exp, r (r has no leading whitespace)
-- returns nil, err
parse_atom = function(s)
	local atom = {}
	local n, r = s:match("^([^%s%(%)\\]*)(.-)$")
	if #n > 0 then table.insert(atom, n) end
	while r:match("^\\") do
		local escaped = r:match("^\\(.)")
		if not escaped then
			-- a backslash with nothing after it: report it instead of crashing on a nil capture
			return nil, "dangling \\ at end of input"
		end
		table.insert(atom, escaped)
		n, r = r:match("^\\.([^%s%(%)\\]*)(.-)$")
		if #n > 0 then table.insert(atom, n) end
	end
	if #atom == 0 then return nil, "no atom found" end
	return table.concat(atom), r:match("^%s*(.-)$")
end

-- s has no leading whitespace
-- returns exp, r (r has no leading whitespace)
-- returns nil, err
parse = function(s)
	if s:match("^%(") then
		-- an opening parenthesis commits to an expression, so its own error is reported
		return parse_exp(s)
	end
	if s == "" or s:match("^[%s%)]") then
		return nil, "no expression found"
	end
	return parse_atom(s)
end

-- Parse s (leading whitespace is dropped first) and print either the parsed structure,
-- the parse error, or a complaint about trailing input.
local function test(s)
	local trimmed = s:match("^%s*(.-)$")
	local parsed, r = parse(trimmed)
	if not parsed then
		print(r)
	elseif r:match(".") then
		print(("unexpected %q at end of expression"):format(r))
	else
		print(require("inspect")(parsed))
	end
end

test("((str (Hel\\)lo world) sa mère) (lol))")

test("\\lol\\ wut")

test("()")