Fix confdir passing to workers

This commit is contained in:
Leo Vasanko 2023-11-21 13:05:55 +00:00
parent a267a67947
commit 13e1a19c5b

View File

@ -7,7 +7,7 @@ from functools import wraps
from hashlib import sha256 from hashlib import sha256
from pathlib import Path, PurePath from pathlib import Path, PurePath
from time import time from time import time
from contextlib import suppress
import msgspec import msgspec
@ -37,6 +37,22 @@ config = None
conffile = None conffile = None
def init_confdir():
if p := os.environ.get("CISTA_HOME"):
home = Path(p)
else:
xdg = os.environ.get("XDG_CONFIG_HOME")
home = (
Path(xdg).expanduser() / "cista" if xdg else Path.home() / ".config/cista"
)
if not home.exists()
home.mkdir(parents=True, exist_ok=True)
home.chmod(0o700)
global conffile
conffile = home / "db.toml"
def derived_secret(*params, len=8) -> bytes: def derived_secret(*params, len=8) -> bytes:
"""Used to derive secret keys from the main secret""" """Used to derive secret keys from the main secret"""
# Each part is made the same length by hashing first # Each part is made the same length by hashing first
@ -62,8 +78,8 @@ def dec_hook(typ, obj):
def config_update(modify): def config_update(modify):
global config global config
if not conffile.exists(): if conffile is None:
conffile.parent.mkdir(parents=True, exist_ok=True) init_confdir()
tmpname = conffile.with_suffix(".tmp") tmpname = conffile.with_suffix(".tmp")
try: try:
f = tmpname.open("xb") f = tmpname.open("xb")
@ -77,10 +93,6 @@ def config_update(modify):
old = conffile.read_bytes() old = conffile.read_bytes()
c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook) c = msgspec.toml.decode(old, type=Config, dec_hook=dec_hook)
except FileNotFoundError: except FileNotFoundError:
# No existing config file, make sure we have a folder...
confdir = conffile.parent
confdir.mkdir(parents=True, exist_ok=True)
confdir.chmod(0o700)
old = b"" old = b""
c = None c = None
c = modify(c) c = modify(c)
@ -93,7 +105,9 @@ def config_update(modify):
f.write(new) f.write(new)
f.close() f.close()
if sys.platform == "win32": if sys.platform == "win32":
conffile.unlink() # Windows doesn't support atomic replace # Windows doesn't support atomic replace
with suppress(FileNotFoundError):
conffile.unlink()
tmpname.rename(conffile) # Atomic replace tmpname.rename(conffile) # Atomic replace
except: except:
f.close() f.close()
@ -121,9 +135,8 @@ def modifies_config(modify):
def load_config(): def load_config():
global config, conffile global config, conffile
conffile = ( if conffile is None:
Path(os.environ.get("CISTA_HOME")) or Path.home() / ".local/share/cista/db.toml" init_confdir()
)
config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook) config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook)