Fix typing and import in the config file module.

This commit is contained in:
Leo Vasanko 2025-08-17 10:10:55 -06:00
parent 69a897cfec
commit 091d57dba7

View File

@ -7,9 +7,11 @@ from contextlib import suppress
from functools import wraps
from hashlib import sha256
from pathlib import Path, PurePath
from time import time
from time import sleep, time
from typing import Callable, Concatenate, Literal, ParamSpec
import msgspec
import msgspec.toml
class Config(msgspec.Struct):
@ -22,6 +24,13 @@ class Config(msgspec.Struct):
links: dict[str, Link] = {}
# Typing: arguments for config-modifying functions
P = ParamSpec("P")
ResultStr = Literal["modified", "created", "read"]
RawModifyFunc = Callable[Concatenate[Config, P], Config]
ModifyPublic = Callable[P, ResultStr]
class User(msgspec.Struct, omit_defaults=True):
privileged: bool = False
hash: str = ""
@ -34,11 +43,13 @@ class Link(msgspec.Struct, omit_defaults=True):
expires: int = 0
config = None
conffile = None
# Global variables - initialized during application startup
config: Config
conffile: Path
def init_confdir():
def init_confdir() -> None:
global conffile
if p := os.environ.get("CISTA_HOME"):
home = Path(p)
else:
@ -49,8 +60,6 @@ def init_confdir():
if not home.is_dir():
home.mkdir(parents=True, exist_ok=True)
home.chmod(0o700)
global conffile
conffile = home / "db.toml"
@ -77,10 +86,10 @@ def dec_hook(typ, obj):
raise TypeError
def config_update(modify):
def config_update(
modify: RawModifyFunc,
) -> ResultStr | Literal["collision"]:
global config
if conffile is None:
init_confdir()
tmpname = conffile.with_suffix(".tmp")
try:
f = tmpname.open("xb")
@ -95,7 +104,7 @@ def config_update(modify):
c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook)
except FileNotFoundError:
old = b""
c = None
c = Config(path=Path(), listen="", secret=secrets.token_hex(12))
c = modify(c)
new = msgspec.toml.encode(c, enc_hook=enc_hook)
if old == new:
@ -118,17 +127,23 @@ def config_update(modify):
return "modified" if old else "created"
def modifies_config(modify):
"""Decorator for functions that modify the config file"""
def modifies_config(
modify: Callable[Concatenate[Config, P], Config],
) -> Callable[P, ResultStr]:
"""Decorator for functions that modify the config file
The decorated function takes as first arg Config and returns it modified.
The wrapper handles atomic modification and returns a string indicating the result.
"""
@wraps(modify)
def wrapper(*args, **kwargs):
def m(c):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ResultStr:
def m(c: Config) -> Config:
return modify(c, *args, **kwargs)
# Retry modification in case of write collision
while (c := config_update(m)) == "collision":
time.sleep(0.01)
sleep(0.01)
return c
return wrapper
@ -136,7 +151,6 @@ def modifies_config(modify):
def load_config():
global config
if conffile is None:
init_confdir()
config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook)
@ -145,7 +159,7 @@ def load_config():
def update_config(conf: Config, changes: dict) -> Config:
"""Create/update the config with new values, respecting changes done by others."""
# Encode into dict, update values with new, convert to Config
settings = {} if conf is None else msgspec.to_builtins(conf, enc_hook=enc_hook)
settings = msgspec.to_builtins(conf, enc_hook=enc_hook)
settings.update(changes)
return msgspec.convert(settings, Config, dec_hook=dec_hook)
@ -155,8 +169,13 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
"""Create/update a user with new values, respecting changes done by others."""
# Encode into dict, update values with new, convert to Config
try:
u = conf.users[name].__copy__()
except (KeyError, AttributeError):
# Copy user by converting to dict and back
u = msgspec.convert(
msgspec.to_builtins(conf.users[name], enc_hook=enc_hook),
User,
dec_hook=dec_hook,
)
except KeyError:
u = User()
if "password" in changes:
from . import auth
@ -165,7 +184,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
del changes["password"]
udict = msgspec.to_builtins(u, enc_hook=enc_hook)
udict.update(changes)
settings = msgspec.to_builtins(conf, enc_hook=enc_hook) if conf else {"users": {}}
settings = msgspec.to_builtins(conf, enc_hook=enc_hook)
settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook)
return msgspec.convert(settings, Config, dec_hook=dec_hook)
@ -173,6 +192,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
@modifies_config
def del_user(conf: Config, name: str) -> Config:
"""Delete named user account."""
ret = conf.__copy__()
ret.users.pop(name)
return ret
# Create a copy by converting to dict and back
settings = msgspec.to_builtins(conf, enc_hook=enc_hook)
settings["users"].pop(name)
return msgspec.convert(settings, Config, dec_hook=dec_hook)