From b6b387d09bc867701173b370e52686c0ecd7f388 Mon Sep 17 00:00:00 2001 From: Leo Vasanko Date: Wed, 15 Nov 2023 11:02:44 -0800 Subject: [PATCH] Code cleanup and bugfixes: - Resolve threading deadlock when multiple watch request arrived at the same moment. - Implement more graceful server exit. - Reduce excessive logging. - Fix unix socket clearing; until Sanic starts accepting Path for unix socket name. --- cista/api.py | 21 +++++++++++---- cista/app.py | 58 +++++++++++++++++++++++----------------- cista/auth.py | 2 +- cista/serve.py | 2 +- cista/util/apphelpers.py | 1 - cista/watching.py | 22 ++++++++------- 6 files changed, 63 insertions(+), 43 deletions(-) diff --git a/cista/api.py b/cista/api.py index 99da767..ca097ea 100644 --- a/cista/api.py +++ b/cista/api.py @@ -111,13 +111,24 @@ async def watch(req, ws): ) uuid = token_bytes(16) try: - with watching.state.lock: - q = watching.pubsub[uuid] = asyncio.Queue() - # Init with disk usage and full tree - await ws.send(watching.format_space(watching.state.space)) - await ws.send(watching.format_root(watching.state.root)) + q, space, root = await asyncio.get_event_loop().run_in_executor( + req.app.ctx.threadexec, subscribe, uuid, ws + ) + await ws.send(space) + await ws.send(root) # Send updates while True: await ws.send(await q.get()) finally: del watching.pubsub[uuid] + + +def subscribe(uuid, ws): + with watching.state.lock: + q = watching.pubsub[uuid] = asyncio.Queue() + # Init with disk usage and full tree + return ( + q, + watching.format_space(watching.state.space), + watching.format_root(watching.state.root), + ) diff --git a/cista/app.py b/cista/app.py index ef0534a..4d4e1bb 100644 --- a/cista/app.py +++ b/cista/app.py @@ -1,6 +1,8 @@ import asyncio import datetime +import logging import mimetypes +import threading from concurrent.futures import ThreadPoolExecutor from pathlib import Path, PurePath, PurePosixPath from stat import S_IFDIR, S_IFREG @@ -12,7 +14,7 @@ import sanic.helpers from blake3 import blake3 from sanic import Blueprint, Sanic, empty, raw from sanic.exceptions import Forbidden, NotFound -from sanic.log import logging +from sanic.log import logger from stream_zip import ZIP_AUTO, stream_zip from cista import auth, config, session, watching @@ -31,14 +33,16 @@ app.exception(Exception)(handle_sanic_exception) @app.before_server_start async def main_start(app, loop): config.load_config() - await watching.start(app, loop) + logger.setLevel(logging.INFO) app.ctx.threadexec = ThreadPoolExecutor( max_workers=8, thread_name_prefix="cista-ioworker" ) + await watching.start(app, loop) @app.after_server_stop async def main_stop(app, loop): + quit.set() await watching.stop(app, loop) app.ctx.threadexec.shutdown() @@ -122,7 +126,7 @@ def _load_wwwroot(www): if not wwwnew: msg = f"Web frontend missing from {base}\n Did you forget: hatch build\n" if not www: - logging.warning(msg) + logger.warning(msg) if not app.debug: msg = "Web frontend missing. Cista installation is broken.\n" wwwnew[""] = ( @@ -141,7 +145,7 @@ def _load_wwwroot(www): async def start(app): await load_wwwroot(app) if app.debug: - app.add_task(refresh_wwwroot()) + app.add_task(refresh_wwwroot(), name="refresh_wwwroot") async def load_wwwroot(app): @@ -151,27 +155,31 @@ async def load_wwwroot(app): ) +quit = threading.Event() + + async def refresh_wwwroot(): - while True: - await asyncio.sleep(0.5) - try: - wwwold = www - await load_wwwroot(app) - changes = "" - for name in sorted(www): - attr = www[name] - if wwwold.get(name) == attr: - continue - headers = attr[2] - changes += f"{headers['last-modified']} {headers['etag']} /{name}\n" - for name in sorted(set(wwwold) - set(www)): - changes += f"Deleted /{name}\n" - if changes: - print(f"Updated wwwroot:\n{changes}", end="", flush=True) - except Exception as e: - print("Error loading wwwroot", e) - if not app.debug: - return + try: + while not quit.is_set(): + try: + wwwold = www + await load_wwwroot(app) + changes = "" + for name in sorted(www): + attr = www[name] + if wwwold.get(name) == attr: + continue + headers = attr[2] + changes += f"{headers['last-modified']} {headers['etag']} /{name}\n" + for name in sorted(set(wwwold) - set(www)): + changes += f"Deleted /{name}\n" + if changes: + print(f"Updated wwwroot:\n{changes}", end="", flush=True) + except Exception as e: + print(f"Error loading wwwroot: {e!r}") + await asyncio.sleep(0.5) + except asyncio.CancelledError: + pass @app.route("/", methods=["GET", "HEAD"]) @@ -251,7 +259,7 @@ async def zip_download(req, keys, zipfile, ext): for chunk in stream_zip(local_files(files)): asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result() except Exception: - logging.exception("Error streaming ZIP") + logger.exception("Error streaming ZIP") raise finally: asyncio.run_coroutine_threadsafe(queue.put(None), loop) diff --git a/cista/auth.py b/cista/auth.py index 7efeb8c..09b249f 100644 --- a/cista/auth.py +++ b/cista/auth.py @@ -71,7 +71,7 @@ def verify(request, *, privileged=False): raise Forbidden("Access Forbidden: Only for privileged users", quiet=True) elif config.config.public or request.ctx.user: return - raise Unauthorized("Login required", "cookie", quiet=True) + raise Unauthorized(f"Login required for {request.path}", "cookie", quiet=True) bp = Blueprint("auth") diff --git a/cista/serve.py b/cista/serve.py index 6c1e310..24c9638 100644 --- a/cista/serve.py +++ b/cista/serve.py @@ -51,7 +51,7 @@ def parse_listen(listen): raise ValueError( f"Directory for unix socket does not exist: {unix.parent}/", ) - return "http://localhost", {"unix": unix} + return "http://localhost", {"unix": unix.as_posix()} if re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE): return f"https://{listen}", {"host": listen, "port": 443, "ssl": True} try: diff --git a/cista/util/apphelpers.py b/cista/util/apphelpers.py index ee2562d..c0433b1 100644 --- a/cista/util/apphelpers.py +++ b/cista/util/apphelpers.py @@ -21,7 +21,6 @@ def jres(data, **kwargs): async def handle_sanic_exception(request, e): - logger.exception(e) context, code = {}, 500 message = str(e) if isinstance(e, SanicException): diff --git a/cista/watching.py b/cista/watching.py index bd9e647..5ead7b3 100644 --- a/cista/watching.py +++ b/cista/watching.py @@ -9,7 +9,7 @@ from pathlib import Path, PurePosixPath import msgspec from natsort import humansorted, natsort_keygen, ns -from sanic.log import logging +from sanic.log import logger from cista import config from cista.fileio import fuid @@ -113,7 +113,8 @@ class State: state = State() rootpath: Path = None # type: ignore -quit = False +quit = threading.Event() + modified_flags = ( "IN_CREATE", "IN_DELETE", @@ -129,7 +130,7 @@ def watcher_thread(loop): global rootpath import inotify.adapters - while not quit: + while not quit.is_set(): rootpath = config.config.path i = inotify.adapters.InotifyTree(rootpath.as_posix()) # Initialize the tree from filesystem @@ -144,7 +145,7 @@ def watcher_thread(loop): refreshdl = time.monotonic() + 30.0 for event in i.event_gen(): - if quit: + if quit.is_set(): return # Disk usage update du = shutil.disk_usage(rootpath) @@ -174,7 +175,7 @@ def watcher_thread(loop): def watcher_thread_poll(loop): global rootpath - while not quit: + while not quit.is_set(): rootpath = config.config.path new = walk() with state.lock: @@ -190,7 +191,7 @@ def watcher_thread_poll(loop): state.space = space broadcast(format_space(space), loop) - time.sleep(2.0) + quit.wait(2.0) def walk(rel=PurePosixPath()) -> list[FileEntry]: # noqa: B008 @@ -218,14 +219,14 @@ def _walk(rel: PurePosixPath, isfile: int, st: stat_result) -> list[FileEntry]: try: li = [] for f in path.iterdir(): - if quit: + if quit.is_set(): raise SystemExit("quit") if f.name.startswith("."): continue # No dotfiles s = f.stat() li.append((int(not stat.S_ISDIR(s.st_mode)), f.name, s)) for [isfile, name, s] in humansorted(li): - if quit: + if quit.is_set(): raise SystemExit("quit") subtree = _walk(rel / name, isfile, s) child = subtree[0] @@ -316,7 +317,7 @@ async def abroadcast(msg): queue.put_nowait(msg) except Exception: # Log because asyncio would silently eat the error - logging.exception("Broadcast error") + logger.exception("Broadcast error") async def start(app, loop): @@ -325,11 +326,12 @@ async def start(app, loop): app.ctx.watcher = threading.Thread( target=watcher_thread if use_inotify else watcher_thread_poll, args=[loop], + name="watcher", ) app.ctx.watcher.start() async def stop(app, loop): global quit - quit = True + quit.set() app.ctx.watcher.join()