diff --git a/cista/config.py b/cista/config.py index 03fbfdd..093d0c2 100644 --- a/cista/config.py +++ b/cista/config.py @@ -7,9 +7,11 @@ from contextlib import suppress from functools import wraps from hashlib import sha256 from pathlib import Path, PurePath -from time import time +from time import sleep, time +from typing import Callable, Concatenate, Literal, ParamSpec import msgspec +import msgspec.toml class Config(msgspec.Struct): @@ -22,6 +24,13 @@ class Config(msgspec.Struct): links: dict[str, Link] = {} +# Typing: arguments for config-modifying functions +P = ParamSpec("P") +ResultStr = Literal["modified", "created", "read"] +RawModifyFunc = Callable[Concatenate[Config, P], Config] +ModifyPublic = Callable[P, ResultStr] + + class User(msgspec.Struct, omit_defaults=True): privileged: bool = False hash: str = "" @@ -34,11 +43,13 @@ class Link(msgspec.Struct, omit_defaults=True): expires: int = 0 -config = None -conffile = None +# Global variables - initialized during application startup +config: Config +conffile: Path -def init_confdir(): +def init_confdir() -> None: + global conffile if p := os.environ.get("CISTA_HOME"): home = Path(p) else: @@ -49,8 +60,6 @@ def init_confdir(): if not home.is_dir(): home.mkdir(parents=True, exist_ok=True) home.chmod(0o700) - - global conffile conffile = home / "db.toml" @@ -77,10 +86,10 @@ def dec_hook(typ, obj): raise TypeError -def config_update(modify): +def config_update( + modify: RawModifyFunc, +) -> ResultStr | Literal["collision"]: global config - if conffile is None: - init_confdir() tmpname = conffile.with_suffix(".tmp") try: f = tmpname.open("xb") @@ -95,7 +104,7 @@ def config_update(modify): c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook) except FileNotFoundError: old = b"" - c = None + c = Config(path=Path(), listen="", secret=secrets.token_hex(12)) c = modify(c) new = msgspec.toml.encode(c, enc_hook=enc_hook) if old == new: @@ -118,17 +127,23 @@ def config_update(modify): return "modified" if old else "created" -def modifies_config(modify): - """Decorator for functions that modify the config file""" +def modifies_config( + modify: Callable[Concatenate[Config, P], Config], +) -> Callable[P, ResultStr]: + """Decorator for functions that modify the config file + + The decorated function takes as first arg Config and returns it modified. + The wrapper handles atomic modification and returns a string indicating the result. + """ @wraps(modify) - def wrapper(*args, **kwargs): - def m(c): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> ResultStr: + def m(c: Config) -> Config: return modify(c, *args, **kwargs) # Retry modification in case of write collision while (c := config_update(m)) == "collision": - time.sleep(0.01) + sleep(0.01) return c return wrapper @@ -136,8 +151,7 @@ def modifies_config(modify): def load_config(): global config - if conffile is None: - init_confdir() + init_confdir() config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook) @@ -145,7 +159,7 @@ def load_config(): def update_config(conf: Config, changes: dict) -> Config: """Create/update the config with new values, respecting changes done by others.""" # Encode into dict, update values with new, convert to Config - settings = {} if conf is None else msgspec.to_builtins(conf, enc_hook=enc_hook) + settings = msgspec.to_builtins(conf, enc_hook=enc_hook) settings.update(changes) return msgspec.convert(settings, Config, dec_hook=dec_hook) @@ -155,8 +169,13 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: """Create/update a user with new values, respecting changes done by others.""" # Encode into dict, update values with new, convert to Config try: - u = conf.users[name].__copy__() - except (KeyError, AttributeError): + # Copy user by converting to dict and back + u = msgspec.convert( + msgspec.to_builtins(conf.users[name], enc_hook=enc_hook), + User, + dec_hook=dec_hook, + ) + except KeyError: u = User() if "password" in changes: from . import auth @@ -165,7 +184,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: del changes["password"] udict = msgspec.to_builtins(u, enc_hook=enc_hook) udict.update(changes) - settings = msgspec.to_builtins(conf, enc_hook=enc_hook) if conf else {"users": {}} + settings = msgspec.to_builtins(conf, enc_hook=enc_hook) settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook) return msgspec.convert(settings, Config, dec_hook=dec_hook) @@ -173,6 +192,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: @modifies_config def del_user(conf: Config, name: str) -> Config: """Delete named user account.""" - ret = conf.__copy__() - ret.users.pop(name) - return ret + # Create a copy by converting to dict and back + settings = msgspec.to_builtins(conf, enc_hook=enc_hook) + settings["users"].pop(name) + return msgspec.convert(settings, Config, dec_hook=dec_hook)