From e71bff9562ef5f69e8b321a3c86c66a0fc679b26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Reuh=20Fildadut?= Date: Wed, 27 Dec 2023 21:25:14 +0100 Subject: [PATCH] Replace persistent variable system Previous system linked the variable name with the saved value, meaning the variable could not be renamed or moved outside the global scope. Instead we propose to store all persistent values in a global table, identifying each by a key. To still allow nice manipulation with identifiers, the alias syntax replace the persistent syntax for symbols - an aliases symbol will act as if a function call was used in place of the identifier when it appear. --- ast/ArgumentTuple.lua | 15 ++++++ ast/Definition.lua | 18 ++++--- ast/Environment.lua | 38 ++++++++++----- ast/Symbol.lua | 28 ++++++----- ast/abstract/Node.lua | 1 + .../primary/function_definition.lua | 6 +-- parser/expression/primary/symbol.lua | 6 +-- .../secondary/infix/assignment_call.lua | 4 +- state/ScopeStack.lua | 7 +-- state/State.lua | 22 ++++----- state/persistent_manager.lua | 47 +++++++++++++++++++ stdlib/init.lua | 3 +- stdlib/persist.lua | 32 +++++++++++++ 13 files changed, 169 insertions(+), 58 deletions(-) create mode 100644 state/persistent_manager.lua create mode 100644 stdlib/persist.lua diff --git a/ast/ArgumentTuple.lua b/ast/ArgumentTuple.lua index 0497911..da60e3b 100644 --- a/ast/ArgumentTuple.lua +++ b/ast/ArgumentTuple.lua @@ -98,6 +98,21 @@ ArgumentTuple = ast.abstract.Node { end return r end, + -- recreate new argumenttuple with an assignment argument added + with_assignment = function(self, assignment) + local r = ArgumentTuple:new() + for i=1, self.arity do + if self.positional[i] then + r:add_positional(self.positional[i]) + elseif self.named[i] then + r:add_named(Identifier:new(self.named[i]), self.named[self.named[i]]) + else + r:add_assignment(self.assignment) -- welp it'll error below anyway + end + end + r:add_assignment(assignment) + return r + end, -- return specificity (>=0), secondary specificity (>=0) -- return false, failure message diff --git a/ast/Definition.lua b/ast/Definition.lua index d69e320..7f198e7 100644 --- a/ast/Definition.lua +++ b/ast/Definition.lua @@ -30,12 +30,16 @@ local Definition = ast.abstract.Node { end local symbol = self.symbol:eval(state) - local val = self.expression:eval(state) - - if Overloadable:issub(val) then - state.scope:define_overloadable(symbol, val) + if symbol.alias then + state.scope:define_alias(symbol, self.expression) else - state.scope:define(symbol, val) + local val = self.expression:eval(state) + + if Overloadable:issub(val) then + state.scope:define_overloadable(symbol, val) + else + state.scope:define(symbol, val) + end end return Nil:new() @@ -46,7 +50,9 @@ local Definition = ast.abstract.Node { symbol:prepare(state) val:prepare(state) - if Overloadable:issub(val) then + if self.symbol.alias then + state.scope:define(symbol:with{ alias = false }, val) -- disable alias to avoid call in Identifier:_prepare + elseif Overloadable:issub(val) then state.scope:define_overloadable(symbol, val) else state.scope:define(symbol, val) diff --git a/ast/Environment.lua b/ast/Environment.lua index adf8d63..aec98b6 100644 --- a/ast/Environment.lua +++ b/ast/Environment.lua @@ -16,7 +16,11 @@ local VariableMetadata = ast.abstract.Runtime { self.branched = Branched:new(state, value) end, get = function(self, state) - return self.branched:get(state) + if self.symbol.alias then + return self.branched:get(state):call(state, ArgumentTuple:new()) + else + return self.branched:get(state) + end end, set = function(self, state, value) assert(not self.symbol.constant, ("trying to change the value of constant %s"):format(self.symbol.string)) @@ -24,7 +28,13 @@ local VariableMetadata = ast.abstract.Runtime { local r = self.symbol.type_check:call(state, ArgumentTuple:new(value)) if not r:truthy() then error(("type check failure for %s; %s does not satisfy %s"):format(self.symbol.string, value, self.symbol.type_check)) end end - self.branched:set(state, value) + if self.symbol.alias then + local assign_args = ArgumentTuple:new() + assign_args:add_assignment(value) + self.branched:get(state):call(state, assign_args) + else + self.branched:set(state, value) + end end, _format = function(self, ...) @@ -108,6 +118,19 @@ local Environment = ast.abstract.Runtime { self:define(state, symbol, exp) end end, + define_alias = function(self, state, symbol, call) + assert(symbol.alias, "symbol is not an alias") + assert(call.type == "call", "alias expression must be a call") + + local get = ast.Function:new(ast.ParameterTuple:new(), call):eval(state) + + local set_param = ast.ParameterTuple:new() + set_param:insert_assignment(ast.FunctionParameter:new(ast.Identifier:new("value"))) + local assign_expr = ast.Call:new(call.func, call.arguments:with_assignment(ast.Identifier:new("value"))) + local set = ast.Function:new(set_param, assign_expr):eval(state) + + self:define(state, symbol, ast.Overload:new(get, set)) + end, -- returns bool if variable defined in current or parent environment defined = function(self, state, identifier) @@ -156,17 +179,6 @@ local Environment = ast.abstract.Runtime { return self:_get_variable(state, identifier):set(state, val) end, - -- returns a list {[symbol]=val,...} of all persistent variables in the current strict layer - list_persistent = function(self, state) - assert(self.export, "not an export scope layer") - local r = {} - for _, vm in self.variables:iter(state) do - if vm.symbol.persistent then - r[vm.symbol] = vm:get(state) - end - end - return r - end, -- returns a list {[symbol]=val,...} of all exported variables in the current strict layer list_exported = function(self, state) assert(self.export, "not an export scope layer") diff --git a/ast/Symbol.lua b/ast/Symbol.lua index a64d20a..0b0706f 100644 --- a/ast/Symbol.lua +++ b/ast/Symbol.lua @@ -11,8 +11,8 @@ Symbol = ast.abstract.Node { constant = nil, -- bool type_check = nil, -- exp + alias = nil, -- bool exported = nil, -- bool - persistent = nil, -- bool, imply exported confined_to_branch = nil, -- bool @@ -20,23 +20,29 @@ Symbol = ast.abstract.Node { modifiers = modifiers or {} self.string = str self.constant = modifiers.constant - self.persistent = modifiers.persistent self.type_check = modifiers.type_check + self.alias = modifiers.alias self.confined_to_branch = modifiers.confined_to_branch - self.exported = modifiers.exported or modifiers.persistent + self.exported = modifiers.exported if self.type_check then self.format_priority = operator_priority["_::_"] end end, _eval = function(self, state) - return Symbol:new(self.string, { - constant = self.constant, - persistent = self.persistent, - type_check = self.type_check and self.type_check:eval(state), - confined_to_branch = self.confined_to_branch, - exported = self.exported - }) + return self:with { + type_check = self.type_check and self.type_check:eval(state) + } + end, + + with = function(self, modifiers) + modifiers = modifiers or {} + for _, k in ipairs{"constant", "type_check", "alias", "exported", "confined_to_branch"} do + if modifiers[k] == nil then + modifiers[k] = self[k] + end + end + return Symbol:new(self.string, modifiers) end, _hash = function(self) @@ -48,7 +54,7 @@ Symbol = ast.abstract.Node { if self.constant then s = s .. ":" end - if self.persistent then + if self.alias then s = s .. "&" end if self.exported then diff --git a/ast/abstract/Node.lua b/ast/abstract/Node.lua index bcf2475..e7920da 100644 --- a/ast/abstract/Node.lua +++ b/ast/abstract/Node.lua @@ -85,6 +85,7 @@ Node = class { local s, r = pcall(self._eval, self, state) if s then r._evaluated = true + r:set_source(self.source) return r else error(format_error(state, self, r), 0) diff --git a/parser/expression/primary/function_definition.lua b/parser/expression/primary/function_definition.lua index e9c57c7..2ef37fd 100644 --- a/parser/expression/primary/function_definition.lua +++ b/parser/expression/primary/function_definition.lua @@ -167,11 +167,11 @@ return primary { local mod_const, mod_exported, rem = source:consume(str:match("^(%:(:?)([&@]?)%$)(.-)$")) -- get modifiers - local constant, exported, persistent + local constant, exported, alias if mod_const == ":" then constant = true end if mod_exported == "@" then exported = true - elseif mod_exported == "&" then persistent = true end - local modifiers = { constant = constant, exported = exported, persistent = persistent } + elseif mod_exported == "&" then alias = true end + local modifiers = { constant = constant, exported = exported, alias = alias } -- search for a valid signature local symbol, parameters diff --git a/parser/expression/primary/symbol.lua b/parser/expression/primary/symbol.lua index 5221624..967160c 100644 --- a/parser/expression/primary/symbol.lua +++ b/parser/expression/primary/symbol.lua @@ -16,11 +16,11 @@ return primary { parse = function(self, source, str) local mod_const, mod_export, rem = source:consume(str:match("^(%:(:?)([&@]?))(.-)$")) - local constant, persistent, type_check_exp, exported + local constant, alias, type_check_exp, exported -- get modifier if mod_const == ":" then constant = true end - if mod_export == "&" then persistent = true + if mod_export == "&" then alias = true elseif mod_export == "@" then exported = true end -- name @@ -35,6 +35,6 @@ return primary { type_check_exp = exp.arguments.positional[2] end - return ident:to_symbol{ constant = constant, persistent = persistent, exported = exported, type_check = type_check_exp }:set_source(source), rem + return ident:to_symbol{ constant = constant, alias = alias, exported = exported, type_check = type_check_exp }:set_source(source), rem end } diff --git a/parser/expression/secondary/infix/assignment_call.lua b/parser/expression/secondary/infix/assignment_call.lua index 493a868..1dabbbd 100644 --- a/parser/expression/secondary/infix/assignment_call.lua +++ b/parser/expression/secondary/infix/assignment_call.lua @@ -18,7 +18,7 @@ return infix { end, build_ast = function(self, left, right) - left.arguments:add_assignment(right) - return Call:new(left.func, left.arguments) -- recreate Call since we modified left.arguments + local args = left.arguments:with_assignment(right) + return Call:new(left.func, args) end, } diff --git a/state/ScopeStack.lua b/state/ScopeStack.lua index 378a84c..e7b1480 100644 --- a/state/ScopeStack.lua +++ b/state/ScopeStack.lua @@ -74,6 +74,7 @@ local ScopeStack = class { -- methods that call the associated method from the current scope, see ast.Environment for details define = function(self, symbol, exp) self.current:define(self.state, symbol, exp) end, define_overloadable = function(self, symbol, exp) return self.current:define_overloadable(self.state, symbol, exp) end, + define_alias = function(self, symbol, exp) return self.current:define_alias(self.state, symbol, exp) end, defined = function(self, identifier) return self.current:defined(self.state, identifier) end, defined_in_current = function(self, symbol) return self.current:defined_in_current(self.state, symbol) end, set = function(self, identifier, exp) self.current:set(self.state, identifier, exp) end, @@ -140,12 +141,6 @@ local ScopeStack = class { return self.current end, - -- return a table { [symbol] = value } of persistent variables defined on the root scope on this branch - list_persistent_global = function(self) - local env = self.stack[1] - return env:list_persistent(self.state) - end, - _debug_state = function(self, filter) filter = filter or "" local s = "current branch id: "..self.state.branch_id.."\n" diff --git a/state/State.lua b/state/State.lua index 225e582..cc0f5a8 100644 --- a/state/State.lua +++ b/state/State.lua @@ -6,6 +6,7 @@ local ScopeStack = require("state.ScopeStack") local tag_manager = require("state.tag_manager") local event_manager = require("state.event_manager") local translation_manager = require("state.translation_manager") +local persistent_manager = require("state.persistent_manager") local uuid = require("common").uuid local parser = require("parser") local binser = require("lib.binser") @@ -30,6 +31,7 @@ State = class { event_manager:setup(self) tag_manager:setup(self) + persistent_manager:setup(self) translation_manager:setup(self) end end, @@ -93,28 +95,22 @@ State = class { ---## Saving and loading persistent variables - --- Return a serialized (string) representation of all global persistent variables in this State. + --- Return a serialized (string) representation of all persistent variables in this State. -- -- This can be loaded back later using `:load`. save = function(self) - local list = self.scope:list_persistent_global() - return binser.serialize(anselme.versions.save, list) + local struct = persistent_manager:capture(self) + return binser.serialize(anselme.versions.save, struct) end, --- Load a string generated by `:save`. -- - -- Variables that do not exist currently in the global scope will be defined, those that do will be overwritten with the loaded data. + -- Variables that already exist will be overwritten with the loaded data. load = function(self, save) - local version, list = binser.deserializeN(save, 2) + local version, struct = binser.deserializeN(save, 2) if version ~= anselme.versions.save then print("Loading a save file generated by a different Anselme version, things may break!") end - self.scope:push_global() - for sym, val in pairs(list) do - if self.scope:defined_in_current(sym) then - self.scope:set(sym:to_identifier(), val) - else - self.scope:define(sym, val) - end + for key, val in struct:iter() do + persistent_manager:set(self, key, val) end - self.scope:pop() end, ---## Current script state diff --git a/state/persistent_manager.lua b/state/persistent_manager.lua new file mode 100644 index 0000000..c43182d --- /dev/null +++ b/state/persistent_manager.lua @@ -0,0 +1,47 @@ +local class = require("class") + +local ast = require("ast") +local Table, Identifier + +local persistent_identifier, persistent_symbol + +local persistent_manager = class { + init = false, + + setup = function(self, state) + state.scope:define(persistent_symbol, Table:new(state)) + end, + + -- set the persistant variable `key` to `value` (evaluated) + set = function(self, state, key, value) + local persistent = state.scope:get(persistent_identifier) + persistent:set(state, key, value) + end, + -- get the persistant variable `key`'s value + -- if `default` is given, will set the variable to this if not currently set + get = function(self, state, key, default) + local persistent = state.scope:get(persistent_identifier) + if not persistent:has(state, key) then + if default then + persistent:set(state, key, default) + else + error("persistent key does not exist") + end + end + return persistent:get(state, key) + end, + + -- returns a struct of the current persisted variables + capture = function(self, state) + local persistent = state.scope:get(persistent_identifier) + return persistent:to_struct(state) + end +} + +package.loaded[...] = persistent_manager +Table, Identifier = ast.Table, ast.Identifier + +persistent_identifier = Identifier:new("_persistent") -- Table of { [key] = Call, ... } +persistent_symbol = persistent_identifier:to_symbol() + +return persistent_manager diff --git a/stdlib/init.lua b/stdlib/init.lua index 1f02a12..f423579 100644 --- a/stdlib/init.lua +++ b/stdlib/init.lua @@ -31,6 +31,7 @@ return function(main_state) "text", "structures", "closure", - "checkpoint" + "checkpoint", + "persist", }) end diff --git a/stdlib/persist.lua b/stdlib/persist.lua new file mode 100644 index 0000000..ce2c885 --- /dev/null +++ b/stdlib/persist.lua @@ -0,0 +1,32 @@ +local ast = require("ast") + +local persistent_manager = require("state.persistent_manager") + +return { + { + "persist", "(key, default)", + function(state, key, default) + return persistent_manager:get(state, key, default) + end + }, + { + "persist", "(key, default) = value", + function(state, key, default, value) + persistent_manager:set(state, key, value) + return ast.Nil:new() + end + }, + { + "persist", "(key)", + function(state, key) + return persistent_manager:get(state, key) + end + }, + { + "persist", "(key) = value", + function(state, key, value) + persistent_manager:set(state, key, value) + return ast.Nil:new() + end + }, +}