Fix typing and import in the config file module.
This commit is contained in:
parent
69a897cfec
commit
091d57dba7
@ -7,9 +7,11 @@ from contextlib import suppress
|
||||
from functools import wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path, PurePath
|
||||
from time import time
|
||||
from time import sleep, time
|
||||
from typing import Callable, Concatenate, Literal, ParamSpec
|
||||
|
||||
import msgspec
|
||||
import msgspec.toml
|
||||
|
||||
|
||||
class Config(msgspec.Struct):
|
||||
@ -22,6 +24,13 @@ class Config(msgspec.Struct):
|
||||
links: dict[str, Link] = {}
|
||||
|
||||
|
||||
# Typing: arguments for config-modifying functions
|
||||
P = ParamSpec("P")
|
||||
ResultStr = Literal["modified", "created", "read"]
|
||||
RawModifyFunc = Callable[Concatenate[Config, P], Config]
|
||||
ModifyPublic = Callable[P, ResultStr]
|
||||
|
||||
|
||||
class User(msgspec.Struct, omit_defaults=True):
|
||||
privileged: bool = False
|
||||
hash: str = ""
|
||||
@ -34,11 +43,13 @@ class Link(msgspec.Struct, omit_defaults=True):
|
||||
expires: int = 0
|
||||
|
||||
|
||||
config = None
|
||||
conffile = None
|
||||
# Global variables - initialized during application startup
|
||||
config: Config
|
||||
conffile: Path
|
||||
|
||||
|
||||
def init_confdir():
|
||||
def init_confdir() -> None:
|
||||
global conffile
|
||||
if p := os.environ.get("CISTA_HOME"):
|
||||
home = Path(p)
|
||||
else:
|
||||
@ -49,8 +60,6 @@ def init_confdir():
|
||||
if not home.is_dir():
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
home.chmod(0o700)
|
||||
|
||||
global conffile
|
||||
conffile = home / "db.toml"
|
||||
|
||||
|
||||
@ -77,10 +86,10 @@ def dec_hook(typ, obj):
|
||||
raise TypeError
|
||||
|
||||
|
||||
def config_update(modify):
|
||||
def config_update(
|
||||
modify: RawModifyFunc,
|
||||
) -> ResultStr | Literal["collision"]:
|
||||
global config
|
||||
if conffile is None:
|
||||
init_confdir()
|
||||
tmpname = conffile.with_suffix(".tmp")
|
||||
try:
|
||||
f = tmpname.open("xb")
|
||||
@ -95,7 +104,7 @@ def config_update(modify):
|
||||
c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook)
|
||||
except FileNotFoundError:
|
||||
old = b""
|
||||
c = None
|
||||
c = Config(path=Path(), listen="", secret=secrets.token_hex(12))
|
||||
c = modify(c)
|
||||
new = msgspec.toml.encode(c, enc_hook=enc_hook)
|
||||
if old == new:
|
||||
@ -118,17 +127,23 @@ def config_update(modify):
|
||||
return "modified" if old else "created"
|
||||
|
||||
|
||||
def modifies_config(modify):
|
||||
"""Decorator for functions that modify the config file"""
|
||||
def modifies_config(
|
||||
modify: Callable[Concatenate[Config, P], Config],
|
||||
) -> Callable[P, ResultStr]:
|
||||
"""Decorator for functions that modify the config file
|
||||
|
||||
The decorated function takes as first arg Config and returns it modified.
|
||||
The wrapper handles atomic modification and returns a string indicating the result.
|
||||
"""
|
||||
|
||||
@wraps(modify)
|
||||
def wrapper(*args, **kwargs):
|
||||
def m(c):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ResultStr:
|
||||
def m(c: Config) -> Config:
|
||||
return modify(c, *args, **kwargs)
|
||||
|
||||
# Retry modification in case of write collision
|
||||
while (c := config_update(m)) == "collision":
|
||||
time.sleep(0.01)
|
||||
sleep(0.01)
|
||||
return c
|
||||
|
||||
return wrapper
|
||||
@ -136,7 +151,6 @@ def modifies_config(modify):
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
if conffile is None:
|
||||
init_confdir()
|
||||
config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook)
|
||||
|
||||
@ -145,7 +159,7 @@ def load_config():
|
||||
def update_config(conf: Config, changes: dict) -> Config:
|
||||
"""Create/update the config with new values, respecting changes done by others."""
|
||||
# Encode into dict, update values with new, convert to Config
|
||||
settings = {} if conf is None else msgspec.to_builtins(conf, enc_hook=enc_hook)
|
||||
settings = msgspec.to_builtins(conf, enc_hook=enc_hook)
|
||||
settings.update(changes)
|
||||
return msgspec.convert(settings, Config, dec_hook=dec_hook)
|
||||
|
||||
@ -155,8 +169,13 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
|
||||
"""Create/update a user with new values, respecting changes done by others."""
|
||||
# Encode into dict, update values with new, convert to Config
|
||||
try:
|
||||
u = conf.users[name].__copy__()
|
||||
except (KeyError, AttributeError):
|
||||
# Copy user by converting to dict and back
|
||||
u = msgspec.convert(
|
||||
msgspec.to_builtins(conf.users[name], enc_hook=enc_hook),
|
||||
User,
|
||||
dec_hook=dec_hook,
|
||||
)
|
||||
except KeyError:
|
||||
u = User()
|
||||
if "password" in changes:
|
||||
from . import auth
|
||||
@ -165,7 +184,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
|
||||
del changes["password"]
|
||||
udict = msgspec.to_builtins(u, enc_hook=enc_hook)
|
||||
udict.update(changes)
|
||||
settings = msgspec.to_builtins(conf, enc_hook=enc_hook) if conf else {"users": {}}
|
||||
settings = msgspec.to_builtins(conf, enc_hook=enc_hook)
|
||||
settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook)
|
||||
return msgspec.convert(settings, Config, dec_hook=dec_hook)
|
||||
|
||||
@ -173,6 +192,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
|
||||
@modifies_config
|
||||
def del_user(conf: Config, name: str) -> Config:
|
||||
"""Delete named user account."""
|
||||
ret = conf.__copy__()
|
||||
ret.users.pop(name)
|
||||
return ret
|
||||
# Create a copy by converting to dict and back
|
||||
settings = msgspec.to_builtins(conf, enc_hook=enc_hook)
|
||||
settings["users"].pop(name)
|
||||
return msgspec.convert(settings, Config, dec_hook=dec_hook)
|
||||
|
Loading…
x
Reference in New Issue
Block a user