Found a couple things

This commit is contained in:
Étienne Fildadut 2025-06-27 22:08:10 +02:00
parent daef111089
commit 5214fed9d3
3 changed files with 344 additions and 1 deletions

View file

@ -1,3 +1,7 @@
# foundobjects
various bits and pieces from around the place
various bits and pieces from around the place
this repo contains various one-file implementations of various algorithms or simple programs, stuff i've done for learning or experimentation purposes
nothing here is new or of particular value, might be helpful for educational purposes or something though so here you go

269
hindleymilner - algorithm j.lua Executable file
View file

@ -0,0 +1,269 @@
-- Hindley-Milner type inference using algorithm J.
-- Base on:
-- https://en.wikipedia.org/wiki/Hindley%E2%80%93Milner_type_system#Algorithm_J
-- https://github.com/jfecher/algorithm-j/blob/master/j.ml
local monotypes = {
base_type = { base = "nil" },
-- variables
variable = { var = true, level = 0 },
-- application/function
fn = { fn = true, from = type, to = type }
}
local polytypes = {
-- quantifier
quantifier = { bound = { var }, type = monotype }
}
---- Monotypes ----
-- Produce a new variable monotype.
local current_level = 0
local function new_var()
return { var = true, level = current_level } -- variable are compared by address: this is a unique by construction, new variable
end
-- Returns a function monotype.
local function make_fn(from, to)
return { fn = true, from = from, to = to }
end
-- Returns a base monotype.
local function make_base(str)
return { base = str }
end
-- Compare two monotypes for equality.
local function equal(monotype_a, monotype_b)
if monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
return true
elseif monotype_a.var and monotype_a == monotype_b then -- variable are compared by address
return true
elseif monotype_a.fn and monotype_b.fn then
return equal(monotype_a.from, monotype_b.from) and equal(monotype_a.to, monotype_b.to)
end
return false
end
-- Union-set algorithm.
-- https://en.wikipedia.org/wiki/Disjoint-set_data_structure
local function find(monotype)
while monotype.parent ~= nil do
monotype = monotype.parent
end
return monotype
end
local function union(monotype_a, monotype_b)
monotype_a, monotype_b = find(monotype_a), find(monotype_b)
if equal(monotype_a, monotype_b) then
return -- already in the same set
end
monotype_a.parent = monotype_b
end
--- Get string representation of a type.
local function type_to_string(t, state)
state = state or { i = 0, map = {} }
t = find(t)
if t.base then
return tostring(t.base)
elseif t.var then
if not state.map[t] then
state.map[t] = string.char(97+state.i)
state.i = state.i + 1
end
return ("%s"):format(state.map[t])
elseif t.fn then
local from = find(t.from)
if from.var or from.base then
return ("%s -> %s"):format(type_to_string(from, state), type_to_string(t.to, state))
else
return ("(%s) -> %s"):format(type_to_string(from, state), type_to_string(t.to, state))
end
end
end
-- Check if vartype appear in monotype
local function occurs(vartype, monotype)
if monotype.base then
return false
elseif monotype.var then
return monotype == vartype
elseif monotype.fn then
return occurs(vartype, monotype.from) or occurs(vartype, monotype.to)
end
end
-- Unification
local function unify(monotype_a, monotype_b)
-- Get monotype representative.
monotype_a = find(monotype_a)
monotype_b = find(monotype_b)
-- Unify this crap.
if monotype_a.base and monotype_b.base and monotype_a.base == monotype_b.base then
return
elseif monotype_a.fn and monotype_b.fn then
unify(monotype_a.from, monotype_b.from)
unify(monotype_a.to, monotype_b.to)
elseif monotype_a.var then
assert(not occurs(monotype_a, monotype_b), "recursive binding")
union(monotype_a, monotype_b)
elseif monotype_b.var then
assert(not occurs(monotype_a, monotype_b), "recursive binding")
union(monotype_b, monotype_a)
else
error(("can't unity type %s and %s"))
end
end
---- Polytypes ----
-- Specializze the polytype by copying the term and replacing the bound type variables consistently by new monotype variables
local function inst(polytype)
local map = {}
for _, var in ipairs(polytype.bound) do
map[var] = new_var()
end
-- copy/replace in the term
local function inst_rec(t)
if t.base then
return t
elseif t.fn then
return make_fn(inst_rec(t.from), inst_rec(t.to))
elseif t.var then
return map[t] or t
end
end
-- do
return inst_rec(polytype.type)
end
-- Create a polytype from a monotype, quantifing all variable types that appear in the monotype.
local function generalize(monotype)
local found = {}
local l = {}
local function list_var_rec(t)
if t.fn then
list_var_rec(t.from)
list_var_rec(t.to)
elseif t.var and not found[t] then
if t.level > current_level then
table.insert(l, t)
end
found[t] = true
end
end
list_var_rec(monotype)
return { bound = l, type = monotype }
end
-- Create a polytype from a monotype, as is.
local function dont_generalize(monotype)
return { bound = {}, type = monotype }
end
---- Inference ----
--- Infer types from expression!
local function infer(expr, env)
env = env or {}
if expr[1] == "base" then
return make_base(expr[2])
-- Var rule
elseif expr[1] == "id" then
local s = assert(env[expr[2]], ("unbound identifier %s"):format(expr[2]))
local t = inst(s)
return t
-- App rule
elseif expr[1] == "call" then
local t0 = infer(expr[2], env)
local t1 = infer(expr[3], env)
local tt = new_var()
unify(t0, make_fn(t1, tt))
return tt
-- Abs rule
elseif expr[1] == "lambda" then
local t = new_var()
local envb = setmetatable({ [expr[2]] = dont_generalize(t) }, { __index = env })
local tt = infer(expr[3], envb)
return make_fn(t, tt)
-- Let rule
elseif expr[1] == "let" then
current_level = current_level + 1
local t = infer(expr[3], env)
current_level = current_level - 1
local envb = setmetatable({ [expr[2]] = generalize(t) }, { __index = env })
local tt = infer(expr[4], envb)
return tt
else
print(require("inspect")(expr))
error(("unknown expression %s"):format(expr[1]))
end
end
---- Test ----
local function parse(s)
if s:match("^%b()") then
s = s:match("^%b()"):match("^%(%s*(.*)%s*%)$")
local l = {}
local i = 1
while s:match("[^%s]", i) do
if s:match("^%b()", i) then
local ss
ss, i = s:match("^(%b())%s*()", i)
table.insert(l, parse(ss))
elseif s:match("^%w+", i) then
local ss
ss, i = s:match("^(%w+)%s*()", i)
table.insert(l, ss)
else
error(("unexpected %q"):format(s:sub(i)))
end
end
return l
elseif s:match("[^%s]") then
error(("expected EOF near %q"):format(s))
end
end
local tests = {
{
exp = "(lambda f (lambda x (call (id f) (id x))))",
result = "(a -> b) -> a -> b"
},
{
exp = "(lambda f (lambda x (call (id f) (call (id f) (id x)))))",
result = "(a -> a) -> a -> a"
},
{
exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (id f)) (call (call (id n) (id f)) (id x)))))))))",
result = "(a -> b -> c) -> (a -> d -> b) -> a -> d -> c"
},
{
exp = "(lambda n (lambda f (lambda x (call (id f) (call (call (id n) (id f)) (id x))))))",
result = "((a -> b) -> c -> a) -> (a -> b) -> c -> b"
},
{
exp = "(lambda m (lambda n (lambda f (lambda x (call (call (id m) (call (id n) (id f))) (id x)))))))))",
result = "(a -> b -> c) -> (d -> a) -> d -> b -> c"
},
{
exp = "(lambda n (lambda f (lambda x (call (call (call (id n) (lambda g (lambda h (call (id h) (call (id g) (id f)))))) (lambda u (id x))) (lambda u (id u))))))",
result = "(((a -> b) -> (b -> c) -> c) -> (d -> e) -> (f -> f) -> g) -> a -> e -> g"
},
{
exp = "(lambda x (let y (id x) (id y)))",
result = "a -> a"
},
{
exp = "(lambda x (let y (lambda z (id x)) (id y)))",
result = "a -> b -> a"
}
}
for _, t in ipairs(tests) do
local r = type_to_string(infer(parse(t.exp)))
if r ~= t.result then
print("invalid result for test", t.exp, "expected", t.result, "but received", r)
end
end

70
sexpr.lua Normal file
View file

@ -0,0 +1,70 @@
--- simple s expressions parser
local parse, parse_exp, parse_atom
-- expression: starts with ( and ends with ), contains a whietspace separated list of expressions and tokens
-- s has no leading whitespace, starts with (
-- returns exp, r (r has no leading whitespace)
-- returns nil, err
parse_exp = function(s)
local r = s:match("^%(%s*(.*)$")
if not r then return nil, "no expression found" end
local exp = {}
repeat
local item, r_item = parse(r)
if item then
table.insert(exp, item)
r = r_item
end
until not item
if not r:match("^%)") then
return nil, "expected closing )"
end
return exp, r:match("^%)%s*(.-)$")
end
-- atom: litteral delimited by whitespace, ), or (; and with escaping using \
-- s has no leading whitespace
-- returns exp, r (r has no leading whitespace)
-- returns nil, err
parse_atom = function(s)
local atom = {}
local n, r = s:match("^([^%s%(%)\\]*)(.-)$")
if #n > 0 then table.insert(atom, n) end
while r:match("^\\") do
table.insert(atom, r:match("^\\(.)"))
n, r = r:match("^\\.([^%s%(%)\\]*)(.-)$")
if #n > 0 then table.insert(atom, n) end
end
if #atom == 0 then return nil, "no atom found" end
return table.concat(atom), r:match("^%s*(.-)$")
end
-- s has no leading whitespace
-- returns exp, r (r has no leading whitespace)
-- returns nil, err
parse = function(s)
local i, r = parse_exp(s)
if i then return i, r end
i, r = parse_atom(s)
if i then return i, r end
return nil, "no expression found"
end
local function test(s)
local trimmed = s:match("^%s*(.-)$")
local parsed, r = parse(s)
if not parsed then
print(r)
elseif r:match(".") then
print(("unexpected %q at end of expression"):format(r))
else
print(require("inspect")(parsed))
end
end
test("((str (Hel\\)lo world) sa mère) (lol))")
test("\\lol\\ wut")
test("()")