diff --git a/ast/ArgumentTuple.lua b/ast/ArgumentTuple.lua index 0497911..da60e3b 100644 --- a/ast/ArgumentTuple.lua +++ b/ast/ArgumentTuple.lua @@ -98,6 +98,21 @@ ArgumentTuple = ast.abstract.Node { end return r end, + -- recreate new argumenttuple with an assignment argument added + with_assignment = function(self, assignment) + local r = ArgumentTuple:new() + for i=1, self.arity do + if self.positional[i] then + r:add_positional(self.positional[i]) + elseif self.named[i] then + r:add_named(Identifier:new(self.named[i]), self.named[self.named[i]]) + else + r:add_assignment(self.assignment) -- welp it'll error below anyway + end + end + r:add_assignment(assignment) + return r + end, -- return specificity (>=0), secondary specificity (>=0) -- return false, failure message diff --git a/ast/Definition.lua b/ast/Definition.lua index d69e320..7f198e7 100644 --- a/ast/Definition.lua +++ b/ast/Definition.lua @@ -30,12 +30,16 @@ local Definition = ast.abstract.Node { end local symbol = self.symbol:eval(state) - local val = self.expression:eval(state) - - if Overloadable:issub(val) then - state.scope:define_overloadable(symbol, val) + if symbol.alias then + state.scope:define_alias(symbol, self.expression) else - state.scope:define(symbol, val) + local val = self.expression:eval(state) + + if Overloadable:issub(val) then + state.scope:define_overloadable(symbol, val) + else + state.scope:define(symbol, val) + end end return Nil:new() @@ -46,7 +50,9 @@ local Definition = ast.abstract.Node { symbol:prepare(state) val:prepare(state) - if Overloadable:issub(val) then + if self.symbol.alias then + state.scope:define(symbol:with{ alias = false }, val) -- disable alias to avoid call in Identifier:_prepare + elseif Overloadable:issub(val) then state.scope:define_overloadable(symbol, val) else state.scope:define(symbol, val) diff --git a/ast/Environment.lua b/ast/Environment.lua index adf8d63..aec98b6 100644 --- a/ast/Environment.lua +++ b/ast/Environment.lua @@ -16,7 +16,11 @@ local VariableMetadata = ast.abstract.Runtime { self.branched = Branched:new(state, value) end, get = function(self, state) - return self.branched:get(state) + if self.symbol.alias then + return self.branched:get(state):call(state, ArgumentTuple:new()) + else + return self.branched:get(state) + end end, set = function(self, state, value) assert(not self.symbol.constant, ("trying to change the value of constant %s"):format(self.symbol.string)) @@ -24,7 +28,13 @@ local VariableMetadata = ast.abstract.Runtime { local r = self.symbol.type_check:call(state, ArgumentTuple:new(value)) if not r:truthy() then error(("type check failure for %s; %s does not satisfy %s"):format(self.symbol.string, value, self.symbol.type_check)) end end - self.branched:set(state, value) + if self.symbol.alias then + local assign_args = ArgumentTuple:new() + assign_args:add_assignment(value) + self.branched:get(state):call(state, assign_args) + else + self.branched:set(state, value) + end end, _format = function(self, ...) @@ -108,6 +118,19 @@ local Environment = ast.abstract.Runtime { self:define(state, symbol, exp) end end, + define_alias = function(self, state, symbol, call) + assert(symbol.alias, "symbol is not an alias") + assert(call.type == "call", "alias expression must be a call") + + local get = ast.Function:new(ast.ParameterTuple:new(), call):eval(state) + + local set_param = ast.ParameterTuple:new() + set_param:insert_assignment(ast.FunctionParameter:new(ast.Identifier:new("value"))) + local assign_expr = ast.Call:new(call.func, call.arguments:with_assignment(ast.Identifier:new("value"))) + local set = ast.Function:new(set_param, assign_expr):eval(state) + + self:define(state, symbol, ast.Overload:new(get, set)) + end, -- returns bool if variable defined in current or parent environment defined = function(self, state, identifier) @@ -156,17 +179,6 @@ local Environment = ast.abstract.Runtime { return self:_get_variable(state, identifier):set(state, val) end, - -- returns a list {[symbol]=val,...} of all persistent variables in the current strict layer - list_persistent = function(self, state) - assert(self.export, "not an export scope layer") - local r = {} - for _, vm in self.variables:iter(state) do - if vm.symbol.persistent then - r[vm.symbol] = vm:get(state) - end - end - return r - end, -- returns a list {[symbol]=val,...} of all exported variables in the current strict layer list_exported = function(self, state) assert(self.export, "not an export scope layer") diff --git a/ast/Symbol.lua b/ast/Symbol.lua index a64d20a..0b0706f 100644 --- a/ast/Symbol.lua +++ b/ast/Symbol.lua @@ -11,8 +11,8 @@ Symbol = ast.abstract.Node { constant = nil, -- bool type_check = nil, -- exp + alias = nil, -- bool exported = nil, -- bool - persistent = nil, -- bool, imply exported confined_to_branch = nil, -- bool @@ -20,23 +20,29 @@ Symbol = ast.abstract.Node { modifiers = modifiers or {} self.string = str self.constant = modifiers.constant - self.persistent = modifiers.persistent self.type_check = modifiers.type_check + self.alias = modifiers.alias self.confined_to_branch = modifiers.confined_to_branch - self.exported = modifiers.exported or modifiers.persistent + self.exported = modifiers.exported if self.type_check then self.format_priority = operator_priority["_::_"] end end, _eval = function(self, state) - return Symbol:new(self.string, { - constant = self.constant, - persistent = self.persistent, - type_check = self.type_check and self.type_check:eval(state), - confined_to_branch = self.confined_to_branch, - exported = self.exported - }) + return self:with { + type_check = self.type_check and self.type_check:eval(state) + } + end, + + with = function(self, modifiers) + modifiers = modifiers or {} + for _, k in ipairs{"constant", "type_check", "alias", "exported", "confined_to_branch"} do + if modifiers[k] == nil then + modifiers[k] = self[k] + end + end + return Symbol:new(self.string, modifiers) end, _hash = function(self) @@ -48,7 +54,7 @@ Symbol = ast.abstract.Node { if self.constant then s = s .. ":" end - if self.persistent then + if self.alias then s = s .. "&" end if self.exported then diff --git a/ast/abstract/Node.lua b/ast/abstract/Node.lua index bcf2475..e7920da 100644 --- a/ast/abstract/Node.lua +++ b/ast/abstract/Node.lua @@ -85,6 +85,7 @@ Node = class { local s, r = pcall(self._eval, self, state) if s then r._evaluated = true + r:set_source(self.source) return r else error(format_error(state, self, r), 0) diff --git a/parser/expression/primary/function_definition.lua b/parser/expression/primary/function_definition.lua index e9c57c7..2ef37fd 100644 --- a/parser/expression/primary/function_definition.lua +++ b/parser/expression/primary/function_definition.lua @@ -167,11 +167,11 @@ return primary { local mod_const, mod_exported, rem = source:consume(str:match("^(%:(:?)([&@]?)%$)(.-)$")) -- get modifiers - local constant, exported, persistent + local constant, exported, alias if mod_const == ":" then constant = true end if mod_exported == "@" then exported = true - elseif mod_exported == "&" then persistent = true end - local modifiers = { constant = constant, exported = exported, persistent = persistent } + elseif mod_exported == "&" then alias = true end + local modifiers = { constant = constant, exported = exported, alias = alias } -- search for a valid signature local symbol, parameters diff --git a/parser/expression/primary/symbol.lua b/parser/expression/primary/symbol.lua index 5221624..967160c 100644 --- a/parser/expression/primary/symbol.lua +++ b/parser/expression/primary/symbol.lua @@ -16,11 +16,11 @@ return primary { parse = function(self, source, str) local mod_const, mod_export, rem = source:consume(str:match("^(%:(:?)([&@]?))(.-)$")) - local constant, persistent, type_check_exp, exported + local constant, alias, type_check_exp, exported -- get modifier if mod_const == ":" then constant = true end - if mod_export == "&" then persistent = true + if mod_export == "&" then alias = true elseif mod_export == "@" then exported = true end -- name @@ -35,6 +35,6 @@ return primary { type_check_exp = exp.arguments.positional[2] end - return ident:to_symbol{ constant = constant, persistent = persistent, exported = exported, type_check = type_check_exp }:set_source(source), rem + return ident:to_symbol{ constant = constant, alias = alias, exported = exported, type_check = type_check_exp }:set_source(source), rem end } diff --git a/parser/expression/secondary/infix/assignment_call.lua b/parser/expression/secondary/infix/assignment_call.lua index 493a868..1dabbbd 100644 --- a/parser/expression/secondary/infix/assignment_call.lua +++ b/parser/expression/secondary/infix/assignment_call.lua @@ -18,7 +18,7 @@ return infix { end, build_ast = function(self, left, right) - left.arguments:add_assignment(right) - return Call:new(left.func, left.arguments) -- recreate Call since we modified left.arguments + local args = left.arguments:with_assignment(right) + return Call:new(left.func, args) end, } diff --git a/state/ScopeStack.lua b/state/ScopeStack.lua index 378a84c..e7b1480 100644 --- a/state/ScopeStack.lua +++ b/state/ScopeStack.lua @@ -74,6 +74,7 @@ local ScopeStack = class { -- methods that call the associated method from the current scope, see ast.Environment for details define = function(self, symbol, exp) self.current:define(self.state, symbol, exp) end, define_overloadable = function(self, symbol, exp) return self.current:define_overloadable(self.state, symbol, exp) end, + define_alias = function(self, symbol, exp) return self.current:define_alias(self.state, symbol, exp) end, defined = function(self, identifier) return self.current:defined(self.state, identifier) end, defined_in_current = function(self, symbol) return self.current:defined_in_current(self.state, symbol) end, set = function(self, identifier, exp) self.current:set(self.state, identifier, exp) end, @@ -140,12 +141,6 @@ local ScopeStack = class { return self.current end, - -- return a table { [symbol] = value } of persistent variables defined on the root scope on this branch - list_persistent_global = function(self) - local env = self.stack[1] - return env:list_persistent(self.state) - end, - _debug_state = function(self, filter) filter = filter or "" local s = "current branch id: "..self.state.branch_id.."\n" diff --git a/state/State.lua b/state/State.lua index 225e582..cc0f5a8 100644 --- a/state/State.lua +++ b/state/State.lua @@ -6,6 +6,7 @@ local ScopeStack = require("state.ScopeStack") local tag_manager = require("state.tag_manager") local event_manager = require("state.event_manager") local translation_manager = require("state.translation_manager") +local persistent_manager = require("state.persistent_manager") local uuid = require("common").uuid local parser = require("parser") local binser = require("lib.binser") @@ -30,6 +31,7 @@ State = class { event_manager:setup(self) tag_manager:setup(self) + persistent_manager:setup(self) translation_manager:setup(self) end end, @@ -93,28 +95,22 @@ State = class { ---## Saving and loading persistent variables - --- Return a serialized (string) representation of all global persistent variables in this State. + --- Return a serialized (string) representation of all persistent variables in this State. -- -- This can be loaded back later using `:load`. save = function(self) - local list = self.scope:list_persistent_global() - return binser.serialize(anselme.versions.save, list) + local struct = persistent_manager:capture(self) + return binser.serialize(anselme.versions.save, struct) end, --- Load a string generated by `:save`. -- - -- Variables that do not exist currently in the global scope will be defined, those that do will be overwritten with the loaded data. + -- Variables that already exist will be overwritten with the loaded data. load = function(self, save) - local version, list = binser.deserializeN(save, 2) + local version, struct = binser.deserializeN(save, 2) if version ~= anselme.versions.save then print("Loading a save file generated by a different Anselme version, things may break!") end - self.scope:push_global() - for sym, val in pairs(list) do - if self.scope:defined_in_current(sym) then - self.scope:set(sym:to_identifier(), val) - else - self.scope:define(sym, val) - end + for key, val in struct:iter() do + persistent_manager:set(self, key, val) end - self.scope:pop() end, ---## Current script state diff --git a/state/persistent_manager.lua b/state/persistent_manager.lua new file mode 100644 index 0000000..c43182d --- /dev/null +++ b/state/persistent_manager.lua @@ -0,0 +1,47 @@ +local class = require("class") + +local ast = require("ast") +local Table, Identifier + +local persistent_identifier, persistent_symbol + +local persistent_manager = class { + init = false, + + setup = function(self, state) + state.scope:define(persistent_symbol, Table:new(state)) + end, + + -- set the persistant variable `key` to `value` (evaluated) + set = function(self, state, key, value) + local persistent = state.scope:get(persistent_identifier) + persistent:set(state, key, value) + end, + -- get the persistant variable `key`'s value + -- if `default` is given, will set the variable to this if not currently set + get = function(self, state, key, default) + local persistent = state.scope:get(persistent_identifier) + if not persistent:has(state, key) then + if default then + persistent:set(state, key, default) + else + error("persistent key does not exist") + end + end + return persistent:get(state, key) + end, + + -- returns a struct of the current persisted variables + capture = function(self, state) + local persistent = state.scope:get(persistent_identifier) + return persistent:to_struct(state) + end +} + +package.loaded[...] = persistent_manager +Table, Identifier = ast.Table, ast.Identifier + +persistent_identifier = Identifier:new("_persistent") -- Table of { [key] = Call, ... } +persistent_symbol = persistent_identifier:to_symbol() + +return persistent_manager diff --git a/stdlib/init.lua b/stdlib/init.lua index 1f02a12..f423579 100644 --- a/stdlib/init.lua +++ b/stdlib/init.lua @@ -31,6 +31,7 @@ return function(main_state) "text", "structures", "closure", - "checkpoint" + "checkpoint", + "persist", }) end diff --git a/stdlib/persist.lua b/stdlib/persist.lua new file mode 100644 index 0000000..ce2c885 --- /dev/null +++ b/stdlib/persist.lua @@ -0,0 +1,32 @@ +local ast = require("ast") + +local persistent_manager = require("state.persistent_manager") + +return { + { + "persist", "(key, default)", + function(state, key, default) + return persistent_manager:get(state, key, default) + end + }, + { + "persist", "(key, default) = value", + function(state, key, default, value) + persistent_manager:set(state, key, value) + return ast.Nil:new() + end + }, + { + "persist", "(key)", + function(state, key) + return persistent_manager:get(state, key) + end + }, + { + "persist", "(key) = value", + function(state, key, value) + persistent_manager:set(state, key, value) + return ast.Nil:new() + end + }, +}