Code cleanup and bugfixes:

- Resolve threading deadlock when multiple watch request arrived at the same moment.
- Implement more graceful server exit.
- Reduce excessive logging.
- Fix unix socket clearing; until Sanic starts accepting Path for unix socket name.
This commit is contained in:
Leo Vasanko 2023-11-15 11:02:44 -08:00
parent 669762dfe7
commit b6b387d09b
6 changed files with 63 additions and 43 deletions

View File

@ -111,13 +111,24 @@ async def watch(req, ws):
) )
uuid = token_bytes(16) uuid = token_bytes(16)
try: try:
with watching.state.lock: q, space, root = await asyncio.get_event_loop().run_in_executor(
q = watching.pubsub[uuid] = asyncio.Queue() req.app.ctx.threadexec, subscribe, uuid, ws
# Init with disk usage and full tree )
await ws.send(watching.format_space(watching.state.space)) await ws.send(space)
await ws.send(watching.format_root(watching.state.root)) await ws.send(root)
# Send updates # Send updates
while True: while True:
await ws.send(await q.get()) await ws.send(await q.get())
finally: finally:
del watching.pubsub[uuid] del watching.pubsub[uuid]
def subscribe(uuid, ws):
with watching.state.lock:
q = watching.pubsub[uuid] = asyncio.Queue()
# Init with disk usage and full tree
return (
q,
watching.format_space(watching.state.space),
watching.format_root(watching.state.root),
)

View File

@ -1,6 +1,8 @@
import asyncio import asyncio
import datetime import datetime
import logging
import mimetypes import mimetypes
import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path, PurePath, PurePosixPath from pathlib import Path, PurePath, PurePosixPath
from stat import S_IFDIR, S_IFREG from stat import S_IFDIR, S_IFREG
@ -12,7 +14,7 @@ import sanic.helpers
from blake3 import blake3 from blake3 import blake3
from sanic import Blueprint, Sanic, empty, raw from sanic import Blueprint, Sanic, empty, raw
from sanic.exceptions import Forbidden, NotFound from sanic.exceptions import Forbidden, NotFound
from sanic.log import logging from sanic.log import logger
from stream_zip import ZIP_AUTO, stream_zip from stream_zip import ZIP_AUTO, stream_zip
from cista import auth, config, session, watching from cista import auth, config, session, watching
@ -31,14 +33,16 @@ app.exception(Exception)(handle_sanic_exception)
@app.before_server_start @app.before_server_start
async def main_start(app, loop): async def main_start(app, loop):
config.load_config() config.load_config()
await watching.start(app, loop) logger.setLevel(logging.INFO)
app.ctx.threadexec = ThreadPoolExecutor( app.ctx.threadexec = ThreadPoolExecutor(
max_workers=8, thread_name_prefix="cista-ioworker" max_workers=8, thread_name_prefix="cista-ioworker"
) )
await watching.start(app, loop)
@app.after_server_stop @app.after_server_stop
async def main_stop(app, loop): async def main_stop(app, loop):
quit.set()
await watching.stop(app, loop) await watching.stop(app, loop)
app.ctx.threadexec.shutdown() app.ctx.threadexec.shutdown()
@ -122,7 +126,7 @@ def _load_wwwroot(www):
if not wwwnew: if not wwwnew:
msg = f"Web frontend missing from {base}\n Did you forget: hatch build\n" msg = f"Web frontend missing from {base}\n Did you forget: hatch build\n"
if not www: if not www:
logging.warning(msg) logger.warning(msg)
if not app.debug: if not app.debug:
msg = "Web frontend missing. Cista installation is broken.\n" msg = "Web frontend missing. Cista installation is broken.\n"
wwwnew[""] = ( wwwnew[""] = (
@ -141,7 +145,7 @@ def _load_wwwroot(www):
async def start(app): async def start(app):
await load_wwwroot(app) await load_wwwroot(app)
if app.debug: if app.debug:
app.add_task(refresh_wwwroot()) app.add_task(refresh_wwwroot(), name="refresh_wwwroot")
async def load_wwwroot(app): async def load_wwwroot(app):
@ -151,27 +155,31 @@ async def load_wwwroot(app):
) )
quit = threading.Event()
async def refresh_wwwroot(): async def refresh_wwwroot():
while True: try:
await asyncio.sleep(0.5) while not quit.is_set():
try: try:
wwwold = www wwwold = www
await load_wwwroot(app) await load_wwwroot(app)
changes = "" changes = ""
for name in sorted(www): for name in sorted(www):
attr = www[name] attr = www[name]
if wwwold.get(name) == attr: if wwwold.get(name) == attr:
continue continue
headers = attr[2] headers = attr[2]
changes += f"{headers['last-modified']} {headers['etag']} /{name}\n" changes += f"{headers['last-modified']} {headers['etag']} /{name}\n"
for name in sorted(set(wwwold) - set(www)): for name in sorted(set(wwwold) - set(www)):
changes += f"Deleted /{name}\n" changes += f"Deleted /{name}\n"
if changes: if changes:
print(f"Updated wwwroot:\n{changes}", end="", flush=True) print(f"Updated wwwroot:\n{changes}", end="", flush=True)
except Exception as e: except Exception as e:
print("Error loading wwwroot", e) print(f"Error loading wwwroot: {e!r}")
if not app.debug: await asyncio.sleep(0.5)
return except asyncio.CancelledError:
pass
@app.route("/<path:path>", methods=["GET", "HEAD"]) @app.route("/<path:path>", methods=["GET", "HEAD"])
@ -251,7 +259,7 @@ async def zip_download(req, keys, zipfile, ext):
for chunk in stream_zip(local_files(files)): for chunk in stream_zip(local_files(files)):
asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result() asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result()
except Exception: except Exception:
logging.exception("Error streaming ZIP") logger.exception("Error streaming ZIP")
raise raise
finally: finally:
asyncio.run_coroutine_threadsafe(queue.put(None), loop) asyncio.run_coroutine_threadsafe(queue.put(None), loop)

View File

@ -71,7 +71,7 @@ def verify(request, *, privileged=False):
raise Forbidden("Access Forbidden: Only for privileged users", quiet=True) raise Forbidden("Access Forbidden: Only for privileged users", quiet=True)
elif config.config.public or request.ctx.user: elif config.config.public or request.ctx.user:
return return
raise Unauthorized("Login required", "cookie", quiet=True) raise Unauthorized(f"Login required for {request.path}", "cookie", quiet=True)
bp = Blueprint("auth") bp = Blueprint("auth")

View File

@ -51,7 +51,7 @@ def parse_listen(listen):
raise ValueError( raise ValueError(
f"Directory for unix socket does not exist: {unix.parent}/", f"Directory for unix socket does not exist: {unix.parent}/",
) )
return "http://localhost", {"unix": unix} return "http://localhost", {"unix": unix.as_posix()}
if re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE): if re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE):
return f"https://{listen}", {"host": listen, "port": 443, "ssl": True} return f"https://{listen}", {"host": listen, "port": 443, "ssl": True}
try: try:

View File

@ -21,7 +21,6 @@ def jres(data, **kwargs):
async def handle_sanic_exception(request, e): async def handle_sanic_exception(request, e):
logger.exception(e)
context, code = {}, 500 context, code = {}, 500
message = str(e) message = str(e)
if isinstance(e, SanicException): if isinstance(e, SanicException):

View File

@ -9,7 +9,7 @@ from pathlib import Path, PurePosixPath
import msgspec import msgspec
from natsort import humansorted, natsort_keygen, ns from natsort import humansorted, natsort_keygen, ns
from sanic.log import logging from sanic.log import logger
from cista import config from cista import config
from cista.fileio import fuid from cista.fileio import fuid
@ -113,7 +113,8 @@ class State:
state = State() state = State()
rootpath: Path = None # type: ignore rootpath: Path = None # type: ignore
quit = False quit = threading.Event()
modified_flags = ( modified_flags = (
"IN_CREATE", "IN_CREATE",
"IN_DELETE", "IN_DELETE",
@ -129,7 +130,7 @@ def watcher_thread(loop):
global rootpath global rootpath
import inotify.adapters import inotify.adapters
while not quit: while not quit.is_set():
rootpath = config.config.path rootpath = config.config.path
i = inotify.adapters.InotifyTree(rootpath.as_posix()) i = inotify.adapters.InotifyTree(rootpath.as_posix())
# Initialize the tree from filesystem # Initialize the tree from filesystem
@ -144,7 +145,7 @@ def watcher_thread(loop):
refreshdl = time.monotonic() + 30.0 refreshdl = time.monotonic() + 30.0
for event in i.event_gen(): for event in i.event_gen():
if quit: if quit.is_set():
return return
# Disk usage update # Disk usage update
du = shutil.disk_usage(rootpath) du = shutil.disk_usage(rootpath)
@ -174,7 +175,7 @@ def watcher_thread(loop):
def watcher_thread_poll(loop): def watcher_thread_poll(loop):
global rootpath global rootpath
while not quit: while not quit.is_set():
rootpath = config.config.path rootpath = config.config.path
new = walk() new = walk()
with state.lock: with state.lock:
@ -190,7 +191,7 @@ def watcher_thread_poll(loop):
state.space = space state.space = space
broadcast(format_space(space), loop) broadcast(format_space(space), loop)
time.sleep(2.0) quit.wait(2.0)
def walk(rel=PurePosixPath()) -> list[FileEntry]: # noqa: B008 def walk(rel=PurePosixPath()) -> list[FileEntry]: # noqa: B008
@ -218,14 +219,14 @@ def _walk(rel: PurePosixPath, isfile: int, st: stat_result) -> list[FileEntry]:
try: try:
li = [] li = []
for f in path.iterdir(): for f in path.iterdir():
if quit: if quit.is_set():
raise SystemExit("quit") raise SystemExit("quit")
if f.name.startswith("."): if f.name.startswith("."):
continue # No dotfiles continue # No dotfiles
s = f.stat() s = f.stat()
li.append((int(not stat.S_ISDIR(s.st_mode)), f.name, s)) li.append((int(not stat.S_ISDIR(s.st_mode)), f.name, s))
for [isfile, name, s] in humansorted(li): for [isfile, name, s] in humansorted(li):
if quit: if quit.is_set():
raise SystemExit("quit") raise SystemExit("quit")
subtree = _walk(rel / name, isfile, s) subtree = _walk(rel / name, isfile, s)
child = subtree[0] child = subtree[0]
@ -316,7 +317,7 @@ async def abroadcast(msg):
queue.put_nowait(msg) queue.put_nowait(msg)
except Exception: except Exception:
# Log because asyncio would silently eat the error # Log because asyncio would silently eat the error
logging.exception("Broadcast error") logger.exception("Broadcast error")
async def start(app, loop): async def start(app, loop):
@ -325,11 +326,12 @@ async def start(app, loop):
app.ctx.watcher = threading.Thread( app.ctx.watcher = threading.Thread(
target=watcher_thread if use_inotify else watcher_thread_poll, target=watcher_thread if use_inotify else watcher_thread_poll,
args=[loop], args=[loop],
name="watcher",
) )
app.ctx.watcher.start() app.ctx.watcher.start()
async def stop(app, loop): async def stop(app, loop):
global quit global quit
quit = True quit.set()
app.ctx.watcher.join() app.ctx.watcher.join()