From 6e5cbf9e7ebbabc01a974769b1aff27dc34d9af4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Reuh=20Fildadut?= Date: Sat, 4 Dec 2021 18:13:03 +0100 Subject: [PATCH] Add function references --- anselme.lua | 20 +--- common.lua | 22 +++++ interpreter/common.lua | 2 + interpreter/expression.lua | 27 +++++- interpreter/interpreter.lua | 4 +- parser/common.lua | 14 +-- parser/expression.lua | 49 ++++++---- parser/postparser.lua | 2 +- stdlib/bootscript.lua | 2 + stdlib/types.lua | 2 +- test/tests/function reference call.ans | 27 ++++++ test/tests/function reference call.lua | 123 +++++++++++++++++++++++++ 12 files changed, 245 insertions(+), 49 deletions(-) create mode 100644 common.lua create mode 100644 test/tests/function reference call.ans create mode 100644 test/tests/function reference call.lua diff --git a/anselme.lua b/anselme.lua index 45f4df6..d8f8a2d 100644 --- a/anselme.lua +++ b/anselme.lua @@ -29,6 +29,7 @@ local identifier_pattern = require(anselme_root.."parser.common").identifier_pat local merge_state = require(anselme_root.."interpreter.common").merge_state local stdfuncs = require(anselme_root.."stdlib.functions") local bootscript = require(anselme_root.."stdlib.bootscript") +local copy = require(anselme_root.."common").copy -- wrappers for love.filesystem / luafilesystem local function list_directory(path) @@ -60,25 +61,6 @@ local function is_file(path) end end ---- recursively copy a table, handle cyclic references, no metatable -local function copy(t, cache) - if type(t) == "table" then - cache = cache or {} - if cache[t] then - return cache[t] - else - local c = {} - cache[t] = c - for k, v in pairs(t) do - c[k] = copy(v, cache) - end - return c - end - else - return t - end -end - --- interpreter methods local interpreter_methods = { -- interpreter state diff --git a/common.lua b/common.lua new file mode 100644 index 0000000..a6f0e07 --- /dev/null +++ b/common.lua @@ -0,0 +1,22 @@ +local common +common = { + --- recursively copy a table, handle cyclic references, no metatable + copy = function(t, cache) + if type(t) == "table" then + cache = cache or {} + if cache[t] then + return cache[t] + else + local c = {} + cache[t] = c + for k, v in pairs(t) do + c[k] = common.copy(v, cache) + end + return c + end + else + return t + end + end +} +return common diff --git a/interpreter/common.lua b/interpreter/common.lua index ab93ea9..7e56232 100644 --- a/interpreter/common.lua +++ b/interpreter/common.lua @@ -1,5 +1,6 @@ local atypes, ltypes local eval, run_block +local copy local common --- copy some text & process it to be suited to be sent to Lua in an event @@ -393,5 +394,6 @@ local types = require((...):gsub("interpreter%.common$", "stdlib.types")) atypes, ltypes = types.anselme, types.lua eval = require((...):gsub("common$", "expression")) run_block = require((...):gsub("common$", "interpreter")).run_block +copy = require((...):gsub("interpreter%.common$", "common")).copy return common diff --git a/interpreter/expression.lua b/interpreter/expression.lua index d0e1495..28bf5d5 100644 --- a/interpreter/expression.lua +++ b/interpreter/expression.lua @@ -1,5 +1,6 @@ local expression local to_lua, from_lua, eval_text, is_of_type, truthy, format, pretty_type, get_variable, tags, eval_text_callback, events, flatten_list +local copy local run @@ -76,7 +77,7 @@ local function eval(state, exp) events:pop_buffer(state) if not v then return v, e end return { - type = "eventbuffer", + type = "event buffer", value = l } -- assignment @@ -149,7 +150,12 @@ local function eval(state, exp) elseif exp.type == "variable" then return get_variable(state, exp.name) -- function - elseif exp.type == "function" then + elseif exp.type == "function reference" then + return { + type = "function reference", + value = exp.names + } + elseif exp.type == "function call" then -- eval args: list_brackets local args = {} if exp.argument then @@ -157,6 +163,20 @@ local function eval(state, exp) if not arg then return arg, arge end args = arg.value end + -- function reference: call the referenced function + local variants = exp.variants + if exp.called_name == "()" and args[1].type == "function reference" then + -- remove func ref as first arg + local refv = args[1].value + table.remove(args, 1) + -- get variants of the referenced function + variants = {} + for _, ffqm in ipairs(refv) do + for _, variant in ipairs(state.functions[ffqm]) do + table.insert(variants, variant) + end + end + end -- map named arguments local named_args = {} for i, arg in ipairs(args) do @@ -174,7 +194,7 @@ local function eval(state, exp) -- try to select a function local tried_function_error_messages = {} local selected_variant = { depths = { assignment = nil }, variant = nil } - for _, fn in ipairs(exp.variants) do + for _, fn in ipairs(variants) do -- checkpoint: no args, nothing to select on if fn.type == "checkpoint" then if not selected_variant.variant then @@ -425,5 +445,6 @@ expression = require((...):gsub("interpreter%.expression$", "parser.expression") flatten_list = require((...):gsub("interpreter%.expression$", "parser.common")).flatten_list local common = require((...):gsub("expression$", "common")) to_lua, from_lua, eval_text, is_of_type, truthy, format, pretty_type, get_variable, tags, eval_text_callback, events = common.to_lua, common.from_lua, common.eval_text, common.is_of_type, common.truthy, common.format, common.pretty_type, common.get_variable, common.tags, common.eval_text_callback, common.events +copy = require((...):gsub("interpreter%.expression$", "common")).copy return eval diff --git a/interpreter/interpreter.lua b/interpreter/interpreter.lua index 6e0e598..152065c 100644 --- a/interpreter/interpreter.lua +++ b/interpreter/interpreter.lua @@ -36,7 +36,7 @@ run_line = function(state, line) v, e = eval(state, line.text) if not v then return v, ("%s; at %s"):format(e, line.source) end -- convert text events to choices - if v.type == "eventbuffer" then + if v.type == "event buffer" then local current_tags = tags:current(state) local choice_block_state = { tags = current_tags, block = line.child } local final_buffer = {} @@ -82,7 +82,7 @@ run_line = function(state, line) if not v then return v, ("%s; in automatic event flush at %s"):format(e, line.source) end v, e = eval(state, line.text) if not v then return v, ("%s; at %s"):format(e, line.source) end - if v.type == "eventbuffer" then + if v.type == "event buffer" then v, e = events:write_buffer(state, v.value) if not v then return v, ("%s; at %s"):format(e, line.source) end end diff --git a/parser/common.lua b/parser/common.lua index da0738b..b9d1e7d 100644 --- a/parser/common.lua +++ b/parser/common.lua @@ -174,7 +174,7 @@ common = { if not exp then return nil, rem end if not rem:match("^%s*}") then return nil, ("expected closing } at end of expression before %q"):format(rem) end -- wrap in format() call - local variant, err = common.find_function_variant(state, namespace, "{}", { type = "parentheses", expression = exp }, true) + local variant, err = common.find_function(state, namespace, "{}", { type = "parentheses", expression = exp }, true) if not variant then return variant, err end -- add to text table.insert(l, variant) @@ -211,13 +211,13 @@ common = { return text_exp end end, - -- find compatible function variants from a fully qualified name - -- this functions does not guarantee that functions are fully compatible with the given arguments and only performs a pre-selection without the ones which definitely aren't - -- * list of variants: if success + -- find a list of compatible function variants from a fully qualified name + -- this functions does not guarantee that the returned variants are fully compatible with the given arguments and only performs a pre-selection without the ones which definitely aren't + -- * list of compatible variants: if success -- * nil, err: if error find_function_variant_from_fqm = function(fqm, state, arg) local err = ("compatible function %q variant not found"):format(fqm) - local func = state.functions[fqm] or {} + local func = state.functions[fqm] local args = arg and common.flatten_list(arg) or {} local variants = {} for _, variant in ipairs(func) do @@ -248,7 +248,7 @@ common = { --- same as find_function_variant_from_fqm, but will search every function from the current namespace and up using find -- returns directly a function expression in case of success -- return nil, err otherwise - find_function_variant = function(state, namespace, name, arg, explicit_call) + find_function = function(state, namespace, name, arg, explicit_call) local variants = {} local err = ("compatible function %q variant not found"):format(name) local l = common.find_all(state.aliases, state.functions, namespace, name) @@ -263,7 +263,7 @@ common = { end if #variants > 0 then return { - type = "function", + type = "function call", called_name = name, explicit_call = explicit_call, variants = variants, diff --git a/parser/expression.lua b/parser/expression.lua index ed83ede..0d875f6 100644 --- a/parser/expression.lua +++ b/parser/expression.lua @@ -1,4 +1,4 @@ -local identifier_pattern, format_identifier, find, escape, find_function_variant, parse_text +local identifier_pattern, format_identifier, find, escape, find_function, parse_text, find_all --- binop priority local binops_prio = { @@ -132,7 +132,7 @@ local function expression(s, state, namespace, current_priority, operating_on) right = val } -- find compatible variant - local variant, err = find_function_variant(state, namespace, ":", args, true) + local variant, err = find_function(state, namespace, ":", args, true) if not variant then return variant, err end return expression(r, state, namespace, current_priority, variant) end @@ -171,9 +171,22 @@ local function expression(s, state, namespace, current_priority, operating_on) end end -- find compatible variant - local variant, err = find_function_variant(state, namespace, name, args, explicit_call) + local variant, err = find_function(state, namespace, name, args, explicit_call) if not variant then return variant, err end return expression(r, state, namespace, current_priority, variant) + -- function reference + elseif s:match("^%&"..identifier_pattern) then + local name, r = s:match("^%&("..identifier_pattern..")(.-)$") + name = format_identifier(name) + -- get all functions this name can reference + local lfnqm = find_all(state.aliases, state.functions, namespace, name) + if #lfnqm > 0 then + return expression(r, state, namespace, current_priority, { + type = "function reference", + names = lfnqm + }) + end + return nil, ("can't find function %q to reference"):format(name) end -- unops for prio, oplist in ipairs(unops_prio) do @@ -183,7 +196,7 @@ local function expression(s, state, namespace, current_priority, operating_on) local right, r = expression(s:match("^"..escaped.."(.*)$"), state, namespace, prio) if not right then return nil, ("invalid expression after unop %q: %s"):format(op, r) end -- find variant - local variant, err = find_function_variant(state, namespace, op, right, true) + local variant, err = find_function(state, namespace, op, right, true) if not variant then return variant, err end return expression(r, state, namespace, current_priority, variant) end @@ -227,7 +240,7 @@ local function expression(s, state, namespace, current_priority, operating_on) } end -- find compatible variant - local variant, err = find_function_variant(state, namespace, name, args, explicit_call) + local variant, err = find_function(state, namespace, name, args, explicit_call) if not variant then return variant, err end return expression(r, state, namespace, current_priority, variant) -- other binops @@ -250,12 +263,12 @@ local function expression(s, state, namespace, current_priority, operating_on) left = operating_on, right = right } - local variant, err = find_function_variant(state, namespace, op:match("^(.*)%=$"), args, true) + local variant, err = find_function(state, namespace, op:match("^(.*)%=$"), args, true) if not variant then return variant, err end right = variant end -- assign to a function - if operating_on.type == "function" then + if operating_on.type == "function call" then -- remove non-assignment functions for i=#operating_on.variants, 1, -1 do if not operating_on.variants[i].assignment then @@ -291,7 +304,7 @@ local function expression(s, state, namespace, current_priority, operating_on) left = operating_on, right = right } - local variant, err = find_function_variant(state, namespace, op, args, true) + local variant, err = find_function(state, namespace, op, args, true) if not variant then return variant, err end return expression(r, state, namespace, current_priority, variant) end @@ -300,15 +313,19 @@ local function expression(s, state, namespace, current_priority, operating_on) end end end - -- index + -- index / call if s:match("^%b()") then local content, r = s:match("^(%b())(.*)$") - -- get arguments (parentheses are kept) - local right, r_paren = expression(content, state, namespace) - if not right then return right, r_paren end - if r_paren:match("[^%s]") then return nil, ("unexpected %q at end of index expression"):format(r_paren) end - local args = { type = "list", left = operating_on, right = right } - local variant, err = find_function_variant(state, namespace, "()", args, true) + content = content:gsub("^%(", ""):gsub("%)$", "") + -- get arguments + local args = operating_on + if content:match("[^%s]") then + local right, r_paren = expression(content, state, namespace) + if not right then return right, r_paren end + if r_paren:match("[^%s]") then return nil, ("unexpected %q at end of index/call expression"):format(r_paren) end + args = { type = "list", left = args, right = right } + end + local variant, err = find_function(state, namespace, "()", args, true) if not variant then return variant, err end return expression(r, state, namespace, current_priority, variant) end @@ -319,6 +336,6 @@ end package.loaded[...] = expression local common = require((...):gsub("expression$", "common")) -identifier_pattern, format_identifier, find, escape, find_function_variant, parse_text = common.identifier_pattern, common.format_identifier, common.find, common.escape, common.find_function_variant, common.parse_text +identifier_pattern, format_identifier, find, escape, find_function, parse_text, find_all = common.identifier_pattern, common.format_identifier, common.find, common.escape, common.find_function, common.parse_text, common.find_all return expression diff --git a/parser/postparser.lua b/parser/postparser.lua index f8d85a1..6922f89 100644 --- a/parser/postparser.lua +++ b/parser/postparser.lua @@ -29,7 +29,7 @@ local function parse(state) end param.default = default_exp -- extract type annotation from default value - if default_exp.type == "function" and default_exp.called_name == "::" then + if default_exp.type == "function call" and default_exp.called_name == "::" then param.type_annotation = default_exp.argument.expression.right end end diff --git a/stdlib/bootscript.lua b/stdlib/bootscript.lua index 55028c6..78b3c60 100644 --- a/stdlib/bootscript.lua +++ b/stdlib/bootscript.lua @@ -6,4 +6,6 @@ return [[ :string="string" :list="list" :pair="pair" +:event buffer="event buffer" +:function reference="function reference" ]] diff --git a/stdlib/types.lua b/stdlib/types.lua index 543e449..1f9a485 100644 --- a/stdlib/types.lua +++ b/stdlib/types.lua @@ -131,7 +131,7 @@ types.anselme = { return { [k] = v } end }, - eventbuffer = { + ["event buffer"] = { format = function(val) local v, e = events:write_buffer(anselme.running.state, val) if not v then return v, e end diff --git a/test/tests/function reference call.ans b/test/tests/function reference call.ans new file mode 100644 index 0000000..a985a25 --- /dev/null +++ b/test/tests/function reference call.ans @@ -0,0 +1,27 @@ +$ f(x) + @x+2 + +$ g(x) + @x+3 + +:a = &f + +:b = &g + +3: {f(1)} + +3: {a(1)} + +5: {g(2)} + +4: {a(2)} + +4: {b(1)} + +5: {f(3)} + +13: {b(10)} + +7: {a(5)} + +8: {g(5)} diff --git a/test/tests/function reference call.lua b/test/tests/function reference call.lua new file mode 100644 index 0000000..125e47e --- /dev/null +++ b/test/tests/function reference call.lua @@ -0,0 +1,123 @@ +local _={} +_[55]={} +_[54]={} +_[53]={} +_[52]={} +_[51]={} +_[50]={} +_[49]={} +_[48]={} +_[47]={} +_[46]={} +_[45]={} +_[44]={} +_[43]={} +_[42]={} +_[41]={} +_[40]={} +_[39]={} +_[38]={} +_[37]={text="8",tags=_[55]} +_[36]={text="8: ",tags=_[54]} +_[35]={text="7",tags=_[53]} +_[34]={text="7: ",tags=_[52]} +_[33]={text="13",tags=_[51]} +_[32]={text="13: ",tags=_[50]} +_[31]={text="5",tags=_[49]} +_[30]={text="5: ",tags=_[48]} +_[29]={text="4",tags=_[47]} +_[28]={text="4: ",tags=_[46]} +_[27]={text="4",tags=_[45]} +_[26]={text="4: ",tags=_[44]} +_[25]={text="5",tags=_[43]} +_[24]={text="5: ",tags=_[42]} +_[23]={text="3",tags=_[41]} +_[22]={text="3: ",tags=_[40]} +_[21]={text="3",tags=_[39]} +_[20]={text="3: ",tags=_[38]} +_[19]={_[36],_[37]} +_[18]={_[34],_[35]} +_[17]={_[32],_[33]} +_[16]={_[30],_[31]} +_[15]={_[28],_[29]} +_[14]={_[26],_[27]} +_[13]={_[24],_[25]} +_[12]={_[22],_[23]} +_[11]={_[20],_[21]} +_[10]={"return"} +_[9]={"text",_[19]} +_[8]={"text",_[18]} +_[7]={"text",_[17]} +_[6]={"text",_[16]} +_[5]={"text",_[15]} +_[4]={"text",_[14]} +_[3]={"text",_[13]} +_[2]={"text",_[12]} +_[1]={"text",_[11]} +return {_[1],_[2],_[3],_[4],_[5],_[6],_[7],_[8],_[9],_[10]} +--[[ +{ "text", { { + tags = {}, + text = "3: " + }, { + tags = {}, + text = "3" + } } } +{ "text", { { + tags = {}, + text = "3: " + }, { + tags = {}, + text = "3" + } } } +{ "text", { { + tags = {}, + text = "5: " + }, { + tags = {}, + text = "5" + } } } +{ "text", { { + tags = {}, + text = "4: " + }, { + tags = {}, + text = "4" + } } } +{ "text", { { + tags = {}, + text = "4: " + }, { + tags = {}, + text = "4" + } } } +{ "text", { { + tags = {}, + text = "5: " + }, { + tags = {}, + text = "5" + } } } +{ "text", { { + tags = {}, + text = "13: " + }, { + tags = {}, + text = "13" + } } } +{ "text", { { + tags = {}, + text = "7: " + }, { + tags = {}, + text = "7" + } } } +{ "text", { { + tags = {}, + text = "8: " + }, { + tags = {}, + text = "8" + } } } +{ "return" } +]]-- \ No newline at end of file