diff --git a/ast/Closure.lua b/ast/Closure.lua index 435ac72..05acc68 100644 --- a/ast/Closure.lua +++ b/ast/Closure.lua @@ -44,9 +44,9 @@ Closure = Runtime(Overloadable) { format_parameters = function(self, state) return self.func.parameters:format(state) end, - call_compatible = function(self, state, args) + call_dispatched = function(self, state, args) state.scope:push(self.exported_scope) - local exp = self.func:call_compatible(state, args) + local exp = self.func:call_dispatched(state, args) state.scope:pop() return exp end, diff --git a/ast/Function.lua b/ast/Function.lua index 5a95b66..f5116f4 100644 --- a/ast/Function.lua +++ b/ast/Function.lua @@ -41,7 +41,7 @@ Function = Overloadable { format_parameters = function(self, state) return self.parameters:format(state) end, - call_compatible = function(self, state, args) + call_dispatched = function(self, state, args) state.scope:push() args:bind_parameter_tuple(state, self.parameters) diff --git a/ast/LuaFunction.lua b/ast/LuaFunction.lua index 4392289..719e26b 100644 --- a/ast/LuaFunction.lua +++ b/ast/LuaFunction.lua @@ -34,7 +34,7 @@ LuaFunction = ast.abstract.Runtime(Overloadable) { format_parameters = function(self, state) return self.parameters:format(state) end, - call_compatible = function(self, state, args) + call_dispatched = function(self, state, args) local lua_args = { state } state.scope:push() diff --git a/ast/Overload.lua b/ast/Overload.lua index 1a9b2bf..547f6db 100644 --- a/ast/Overload.lua +++ b/ast/Overload.lua @@ -29,7 +29,7 @@ Overload = ast.abstract.Node { end end, - call = function(self, state, args) + dispatch = function(self, state, args) local failure = {} -- list of failure messages (kept until we find the first success) local success, success_specificity, success_secondary_specificity = nil, -1, -1 -- some might think that iterating a list for every function call is a terrible idea, but that list has a fixed number of elements, so big O notation says suck it up @@ -42,7 +42,7 @@ Overload = ast.abstract.Node { if secondary_specificity > success_secondary_specificity then success, success_specificity, success_secondary_specificity = fn, specificity, secondary_specificity elseif secondary_specificity == success_secondary_specificity then - error(("more than one function match %s, matching functions were at least (specificity %s.%s):\n\t• %s\n\t• %s"):format(args:format(state), specificity, secondary_specificity, fn:format_parameters(state), success:format_parameters(state)), 0) + return nil, ("more than one function match %s, matching functions were at least (specificity %s.%s):\n\t• %s\n\t• %s"):format(args:format(state), specificity, secondary_specificity, fn:format_parameters(state), success:format_parameters(state)) end end -- no need to add error message for less specific function since we already should have at least one success @@ -51,10 +51,9 @@ Overload = ast.abstract.Node { end end if success then - return success:call_compatible(state, args) + return success, args else - -- error - error(("no function match %s, possible functions were:\n\t• %s"):format(args:format(state), table.concat(failure, "\n\t• ")), 0) + return nil, ("no function match %s, possible functions were:\n\t• %s"):format(args:format(state), table.concat(failure, "\n\t• ")) end end } diff --git a/ast/Quote.lua b/ast/Quote.lua index 700bb8e..7646e39 100644 --- a/ast/Quote.lua +++ b/ast/Quote.lua @@ -26,8 +26,14 @@ Quote = ast.abstract.Node { fn(self.expression, ...) end, - call = function(self, state, args) - assert(args.arity == 0, "Quote! does not accept arguments") + dispatch = function(self, state, args) + if args.arity == 0 then + return self, args + else + return nil, "Quote! does not accept arguments" + end + end, + call_dispatched = function(self, state, args) return self.expression:eval(state) end } diff --git a/ast/Typed.lua b/ast/Typed.lua index d6af880..78db976 100644 --- a/ast/Typed.lua +++ b/ast/Typed.lua @@ -2,7 +2,10 @@ local ast = require("ast") local operator_priority = require("common").operator_priority -return ast.abstract.Runtime { +local format_identifier + +local Typed +Typed = ast.abstract.Runtime { type = "typed", expression = nil, @@ -14,6 +17,15 @@ return ast.abstract.Runtime { end, _format = function(self, state, prio, ...) + -- try custom format + if state and state.scope:defined(format_identifier) then + local custom_format = format_identifier:eval(state) + local args = ast.ArgumentTuple:new(self) + local fn, d_args = custom_format:dispatch(state, args) + if fn then + return custom_format:call(state, d_args):format(state, prio, ...) + end + end return ("type(%s, %s)"):format(self.type_expression:format(state, operator_priority["_,_"], ...), self.expression:format_right(state, operator_priority["_,_"], ...)) end, @@ -22,3 +34,8 @@ return ast.abstract.Runtime { fn(self.expression, ...) end } + +package.loaded[...] = Typed +format_identifier = ast.Identifier:new("format") + +return Typed diff --git a/ast/abstract/Node.lua b/ast/abstract/Node.lua index ff65e0a..17cb758 100644 --- a/ast/abstract/Node.lua +++ b/ast/abstract/Node.lua @@ -64,7 +64,7 @@ Node = class { -- to be preferably used during construction only set_source = function(self, source) local str_source = tostring(source) - if self.source == "?" then + if self.source == "?" and str_source ~= "?" then self.source = str_source self:traverse(traverse.set_source, str_source) end @@ -156,16 +156,39 @@ Node = class { return t end, + -- call the node with the given arguments -- return result AST -- arg is a ArgumentTuple node (already evaluated) - -- redefine if relevant + -- do not redefine; instead redefine :dispatch and :call_dispatched call = function(self, state, arg) + local dispatched, dispatched_arg = self:dispatch(state, arg) + if dispatched then + return dispatched:call_dispatched(state, dispatched_arg) + else + error(("can't call %s %s: %s"):format(self.type, self:format(state), dispatched_arg), 0) + end + end, + -- find a function that can be called with the given arguments + -- return function, arg if a function is found that can be called with arg. The returned arg may be different than the input arg. + -- return nil, message if no matching function found + dispatch = function(self, state, arg) + -- by default, look for custom call operator if state.scope:defined(custom_call_identifier) then local custom_call = custom_call_identifier:eval(state) - return custom_call:call(state, arg:with_first_argument(self)) - else - error("trying to call a "..self.type..": "..self:format(state)) + local dispatched, dispatched_arg = custom_call:dispatch(state, arg:with_first_argument(self)) + if dispatched then + return dispatched, dispatched_arg + else + return nil, dispatched_arg + end end + return nil, "not callable" + end, + -- call the node with the given arguments + -- this assumes that this node was correctly dispatched to (was returned by a previous call to :dispatch) + -- you can therefore assume that the arguments are valid and compatible with this node + call_dispatched = function(self, state, arg) + error(("%s is not callable"):format(self.type)) end, -- merge any changes back into the main branch diff --git a/ast/abstract/Overloadable.lua b/ast/abstract/Overloadable.lua index 7b19d08..d8891c7 100644 --- a/ast/abstract/Overloadable.lua +++ b/ast/abstract/Overloadable.lua @@ -1,3 +1,5 @@ +-- for nodes that can be put in an Overload + local ast = require("ast") return ast.abstract.Node { @@ -9,19 +11,21 @@ return ast.abstract.Node { compatible_with_arguments = function(self, state, args) error("not implemented for "..self.type) end, - -- same as :call, but assumes :compatible_with_arguments was checked before the call - call_compatible = function(self, state, args) - error("not implemented for "..self.type) - end, -- return string format_parameters = function(self, state) return self:format(state) end, - -- default for :call - call = function(self, state, args) - assert(self:compatible_with_arguments(state, args)) - return self:call_compatible(state, args) - end + -- can be called either after a successful :dispatch or :compatible_with_arguments + call_dispatched = function(self, state, args) + error("not implemented for "..self.type) + end, + + -- default for :dispatch + dispatch = function(self, state, args) + local s, err = self:compatible_with_arguments(state, args) + if s then return self, args + else return nil, err end + end, }