Fix typing and import in the config file module.
This commit is contained in:
		| @@ -7,9 +7,11 @@ from contextlib import suppress | |||||||
| from functools import wraps | from functools import wraps | ||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
| from pathlib import Path, PurePath | from pathlib import Path, PurePath | ||||||
| from time import time | from time import sleep, time | ||||||
|  | from typing import Callable, Concatenate, Literal, ParamSpec | ||||||
|  |  | ||||||
| import msgspec | import msgspec | ||||||
|  | import msgspec.toml | ||||||
|  |  | ||||||
|  |  | ||||||
| class Config(msgspec.Struct): | class Config(msgspec.Struct): | ||||||
| @@ -22,6 +24,13 @@ class Config(msgspec.Struct): | |||||||
|     links: dict[str, Link] = {} |     links: dict[str, Link] = {} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Typing: arguments for config-modifying functions | ||||||
|  | P = ParamSpec("P") | ||||||
|  | ResultStr = Literal["modified", "created", "read"] | ||||||
|  | RawModifyFunc = Callable[Concatenate[Config, P], Config] | ||||||
|  | ModifyPublic = Callable[P, ResultStr] | ||||||
|  |  | ||||||
|  |  | ||||||
| class User(msgspec.Struct, omit_defaults=True): | class User(msgspec.Struct, omit_defaults=True): | ||||||
|     privileged: bool = False |     privileged: bool = False | ||||||
|     hash: str = "" |     hash: str = "" | ||||||
| @@ -34,11 +43,13 @@ class Link(msgspec.Struct, omit_defaults=True): | |||||||
|     expires: int = 0 |     expires: int = 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| config = None | # Global variables - initialized during application startup | ||||||
| conffile = None | config: Config | ||||||
|  | conffile: Path | ||||||
|  |  | ||||||
|  |  | ||||||
| def init_confdir(): | def init_confdir() -> None: | ||||||
|  |     global conffile | ||||||
|     if p := os.environ.get("CISTA_HOME"): |     if p := os.environ.get("CISTA_HOME"): | ||||||
|         home = Path(p) |         home = Path(p) | ||||||
|     else: |     else: | ||||||
| @@ -49,8 +60,6 @@ def init_confdir(): | |||||||
|     if not home.is_dir(): |     if not home.is_dir(): | ||||||
|         home.mkdir(parents=True, exist_ok=True) |         home.mkdir(parents=True, exist_ok=True) | ||||||
|         home.chmod(0o700) |         home.chmod(0o700) | ||||||
|  |  | ||||||
|     global conffile |  | ||||||
|     conffile = home / "db.toml" |     conffile = home / "db.toml" | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -77,10 +86,10 @@ def dec_hook(typ, obj): | |||||||
|     raise TypeError |     raise TypeError | ||||||
|  |  | ||||||
|  |  | ||||||
| def config_update(modify): | def config_update( | ||||||
|  |     modify: RawModifyFunc, | ||||||
|  | ) -> ResultStr | Literal["collision"]: | ||||||
|     global config |     global config | ||||||
|     if conffile is None: |  | ||||||
|         init_confdir() |  | ||||||
|     tmpname = conffile.with_suffix(".tmp") |     tmpname = conffile.with_suffix(".tmp") | ||||||
|     try: |     try: | ||||||
|         f = tmpname.open("xb") |         f = tmpname.open("xb") | ||||||
| @@ -95,7 +104,7 @@ def config_update(modify): | |||||||
|             c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook) |             c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook) | ||||||
|         except FileNotFoundError: |         except FileNotFoundError: | ||||||
|             old = b"" |             old = b"" | ||||||
|             c = None |             c = Config(path=Path(), listen="", secret=secrets.token_hex(12)) | ||||||
|         c = modify(c) |         c = modify(c) | ||||||
|         new = msgspec.toml.encode(c, enc_hook=enc_hook) |         new = msgspec.toml.encode(c, enc_hook=enc_hook) | ||||||
|         if old == new: |         if old == new: | ||||||
| @@ -118,17 +127,23 @@ def config_update(modify): | |||||||
|     return "modified" if old else "created" |     return "modified" if old else "created" | ||||||
|  |  | ||||||
|  |  | ||||||
| def modifies_config(modify): | def modifies_config( | ||||||
|     """Decorator for functions that modify the config file""" |     modify: Callable[Concatenate[Config, P], Config], | ||||||
|  | ) -> Callable[P, ResultStr]: | ||||||
|  |     """Decorator for functions that modify the config file | ||||||
|  |  | ||||||
|  |     The decorated function takes as first arg Config and returns it modified. | ||||||
|  |     The wrapper handles atomic modification and returns a string indicating the result. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     @wraps(modify) |     @wraps(modify) | ||||||
|     def wrapper(*args, **kwargs): |     def wrapper(*args: P.args, **kwargs: P.kwargs) -> ResultStr: | ||||||
|         def m(c): |         def m(c: Config) -> Config: | ||||||
|             return modify(c, *args, **kwargs) |             return modify(c, *args, **kwargs) | ||||||
|  |  | ||||||
|         # Retry modification in case of write collision |         # Retry modification in case of write collision | ||||||
|         while (c := config_update(m)) == "collision": |         while (c := config_update(m)) == "collision": | ||||||
|             time.sleep(0.01) |             sleep(0.01) | ||||||
|         return c |         return c | ||||||
|  |  | ||||||
|     return wrapper |     return wrapper | ||||||
| @@ -136,8 +151,7 @@ def modifies_config(modify): | |||||||
|  |  | ||||||
| def load_config(): | def load_config(): | ||||||
|     global config |     global config | ||||||
|     if conffile is None: |     init_confdir() | ||||||
|         init_confdir() |  | ||||||
|     config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook) |     config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -145,7 +159,7 @@ def load_config(): | |||||||
| def update_config(conf: Config, changes: dict) -> Config: | def update_config(conf: Config, changes: dict) -> Config: | ||||||
|     """Create/update the config with new values, respecting changes done by others.""" |     """Create/update the config with new values, respecting changes done by others.""" | ||||||
|     # Encode into dict, update values with new, convert to Config |     # Encode into dict, update values with new, convert to Config | ||||||
|     settings = {} if conf is None else msgspec.to_builtins(conf, enc_hook=enc_hook) |     settings = msgspec.to_builtins(conf, enc_hook=enc_hook) | ||||||
|     settings.update(changes) |     settings.update(changes) | ||||||
|     return msgspec.convert(settings, Config, dec_hook=dec_hook) |     return msgspec.convert(settings, Config, dec_hook=dec_hook) | ||||||
|  |  | ||||||
| @@ -155,8 +169,13 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: | |||||||
|     """Create/update a user with new values, respecting changes done by others.""" |     """Create/update a user with new values, respecting changes done by others.""" | ||||||
|     # Encode into dict, update values with new, convert to Config |     # Encode into dict, update values with new, convert to Config | ||||||
|     try: |     try: | ||||||
|         u = conf.users[name].__copy__() |         # Copy user by converting to dict and back | ||||||
|     except (KeyError, AttributeError): |         u = msgspec.convert( | ||||||
|  |             msgspec.to_builtins(conf.users[name], enc_hook=enc_hook), | ||||||
|  |             User, | ||||||
|  |             dec_hook=dec_hook, | ||||||
|  |         ) | ||||||
|  |     except KeyError: | ||||||
|         u = User() |         u = User() | ||||||
|     if "password" in changes: |     if "password" in changes: | ||||||
|         from . import auth |         from . import auth | ||||||
| @@ -165,7 +184,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: | |||||||
|         del changes["password"] |         del changes["password"] | ||||||
|     udict = msgspec.to_builtins(u, enc_hook=enc_hook) |     udict = msgspec.to_builtins(u, enc_hook=enc_hook) | ||||||
|     udict.update(changes) |     udict.update(changes) | ||||||
|     settings = msgspec.to_builtins(conf, enc_hook=enc_hook) if conf else {"users": {}} |     settings = msgspec.to_builtins(conf, enc_hook=enc_hook) | ||||||
|     settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook) |     settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook) | ||||||
|     return msgspec.convert(settings, Config, dec_hook=dec_hook) |     return msgspec.convert(settings, Config, dec_hook=dec_hook) | ||||||
|  |  | ||||||
| @@ -173,6 +192,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: | |||||||
| @modifies_config | @modifies_config | ||||||
| def del_user(conf: Config, name: str) -> Config: | def del_user(conf: Config, name: str) -> Config: | ||||||
|     """Delete named user account.""" |     """Delete named user account.""" | ||||||
|     ret = conf.__copy__() |     # Create a copy by converting to dict and back | ||||||
|     ret.users.pop(name) |     settings = msgspec.to_builtins(conf, enc_hook=enc_hook) | ||||||
|     return ret |     settings["users"].pop(name) | ||||||
|  |     return msgspec.convert(settings, Config, dec_hook=dec_hook) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Leo Vasanko
					Leo Vasanko