import asyncio
from importlib.resources import files

import msgspec
from html5tagger import E
from sanic import Forbidden, Sanic, SanicException, errorpages
from sanic.log import logger
from sanic.response import html, json, redirect

from cista import config, session, watching
from cista.auth import authbp
from cista.fileio import FileServer
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg

app = Sanic("cista")
fileserver = FileServer()
watching.register(app, "/api/watch")
app.blueprint(authbp)


@app.on_request
async def use_session(request):
    """Attach the session to the request and reject cross-site requests.

    Plain GET requests are always let through. Every other request,
    including websocket upgrades, must carry an Origin header whose host
    part equals the request host.

    Raises:
        Forbidden: if the Origin header is missing, malformed or names
            another host.
    """
    request.ctx.session = session.get(request)
    # CSRF protection
    upgrade = (request.headers.upgrade or "").lower()
    if request.method == "GET" and upgrade != "websocket":
        return  # Ordinary GET requests are fine
    # partition() never raises, so a missing or scheme-less Origin ends up
    # as an empty host and is rejected with 403 instead of crashing with 500.
    _, _, origin_host = (request.headers.origin or "").partition("//")
    if not origin_host or origin_host != request.host:
        raise Forbidden("Invalid origin: Cross-Site requests not permitted")


@app.exception(Exception)
async def handle_sanic_exception(request, e):
    """Render any exception as JSON, a flash-message redirect or an HTML page.

    The status code and extra context are taken from SanicException
    instances; every other exception is reported as a 500.
    """
    context, code, message = {}, 500, str(e) or "Internal Server Error"
    if isinstance(e, SanicException):
        context = e.context or {}
        code = e.status_code
    message = f"⚠️ {message}"
    # Non-browsers get JSON errors
    if "text/html" not in request.headers.accept:
        return json(
            {"error": {"code": code, "message": message, **context}},
            status=code,
        )
    # Redirections flash the error message via cookies
    if "redirect" in context:
        res = redirect(context["redirect"])
        res.cookies.add_cookie("message", message, max_age=5)
        return res
    # Otherwise use Sanic's default error page
    return errorpages.HTMLRenderer(request, e, debug=app.debug).full()


@app.on_response
async def resphandler(request, response):
    """Trace every response at debug level."""
    logger.debug("In response handler %s %s", request.path, response)


def asend(ws, msg):
    """Send msg over ws: bytes go out unchanged, anything else as JSON text.

    Returns whatever ws.send() returns, so the caller is expected to await it.
    """
    if isinstance(msg, bytes):
        return ws.send(msg)
    return ws.send(msgspec.json.encode(msg).decode())


async def _recv_file_range(ws):
    """Receive one text frame from ws and decode it as a FileRange.

    Raises:
        ValueError: if the received frame is not text.
    """
    text = await ws.recv()
    if not isinstance(text, str):
        # Only bytes have a meaningful length; describe anything else by repr
        got = (
            f"binary len(data) = {len(text)}"
            if isinstance(text, bytes)
            else repr(text)
        )
        raise ValueError(f"Expected JSON control, got {got}")
    return msgspec.json.decode(text, type=FileRange)


async def _report_failure(ws, req, exc):
    """Log exc with its traceback, then send an ErrorMsg for req over ws."""
    res = ErrorMsg(error=str(exc), req=req)
    # Log before sending so the error is recorded even if the send fails
    logger.error("%r", res, exc_info=exc)
    await asend(ws, res)


@app.before_server_start
async def start_fileserver(app, _):
    """Load the configuration, mount /files and start the file server."""
    config.load_config()  # Main process may have loaded it but we haven't
    app.static(
        "/files",
        config.config.path,
        use_content_range=True,
        stream_large_files=True,
        directory_view=True,
    )
    await fileserver.start()


@app.after_server_stop
async def stop_fileserver(app, _):
    """Stop the file server once the web server has shut down."""
    await fileserver.stop()


@app.get("/")
async def index_page(request):
    """Serve the index page, or redirect to /login when access is not allowed.

    A message found in the "message" cookie is appended to the page as an
    open dialog, and session.flash(res, None) is then called on the response.
    """
    s = config.config.public or session.get(request)
    if not s:
        return redirect("/login")
    index = files("cista").joinpath("static", "index.html").read_text()
    flash = request.cookies.message
    if not flash:
        return html(index)
    index += str(
        E.dialog(
            flash,
            id="flash",
            open=True,
            style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8",
        )
    )
    res = html(index)
    session.flash(res, None)
    return res


@app.websocket("/api/upload")
async def upload(request, ws):
    """Websocket endpoint that receives file ranges from the client.

    Each round starts with a JSON FileRange control frame, followed by
    binary frames that are handed to the file server until req.end is
    reached. A StatusMsg "ack" is sent after every completed range. On any
    error an ErrorMsg is sent and the handler returns.
    """
    alink = fileserver.alink
    while True:
        req = None
        data = None  # Stays None when the requested range is empty
        try:
            req = await _recv_file_range(ws)
            pos = req.start
            while (
                pos < req.end
                and (data := await ws.recv())
                and isinstance(data, bytes)
            ):
                pos += await alink(("upload", req.name, pos, data, req.size))
            if pos != req.end:
                d = f"{len(data)} bytes" if isinstance(data, bytes) else data
                raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
            # Report success
            res = StatusMsg(status="ack", req=req)
            await asend(ws, res)
            await ws.drain()
        except Exception as e:
            await _report_failure(ws, req, e)
            return


@app.websocket("/api/download")
async def download(request, ws):
    """Websocket endpoint that sends file ranges to the client.

    Each round starts with a JSON FileRange control frame. The range is
    read through the file server in chunks of at most 1 MiB and sent as
    binary frames, followed by a StatusMsg "ack". On any error an ErrorMsg
    is sent and the handler returns.
    """
    alink = fileserver.alink
    while True:
        req = None
        try:
            logger.debug("Waiting for download command")
            req = await _recv_file_range(ws)
            logger.debug("download %r", req)
            pos = req.start
            while pos < req.end:
                end = min(req.end, pos + (1 << 20))
                data = await alink(("download", req.name, pos, end))
                # NOTE(review): alink's end-of-file behaviour is not visible
                # here; an empty chunk would never advance pos, so stop.
                if not data:
                    raise ValueError(
                        f"Expected {req.end - pos} more bytes, got none"
                    )
                await asend(ws, data)
                pos += len(data)
            # Report success
            res = StatusMsg(status="ack", req=req)
            await asend(ws, res)
            await ws.drain()
        except Exception as e:
            await _report_failure(ws, req, e)
            return


@app.websocket("/api/control")
async def control(request, websocket):
    """Decode one ControlBase command, run it in a thread and acknowledge it."""
    cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
    await asyncio.to_thread(cmd)
    await asend(websocket, StatusMsg(status="ack", req=cmd))