
141 lines
5.1 KiB
Executable File

import asyncio
import mimetypes
from importlib.resources import files
from urllib.parse import unquote

import msgspec
from html5tagger import E
from sanic import Forbidden, Sanic, SanicException, errorpages, raw
from sanic.exceptions import NotFound
from sanic.log import logger
from sanic.response import html, json, redirect

from cista import config, session, watching
from cista.auth import authbp
from cista.fileio import FileServer
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
from cista.util import filename
# The Sanic application object for cista.
# NOTE(review): `authbp` is imported above but never registered on `app` in this
# view (e.g. via app.blueprint(authbp)) — confirm whether that line was lost.
app = Sanic("cista")
# Shared FileServer; its `alink` is used by the upload/download handlers and it is
# started/stopped by the start_fileserver/stop_fileserver listeners.
fileserver = FileServer()
# Hand the app to cista.watching, presumably to mount its websocket endpoint at
# /api/watch — see cista.watching.register for the details.
watching.register(app, "/api/watch")
async def use_session(request):
    """Attach the session to the request and reject cross-site state changes.

    The session from ``session.get(request)`` is stored on ``request.ctx.session``.
    Plain GET requests (not websocket upgrades) are always allowed. Every other
    request must carry an ``Origin`` header of the form ``scheme://host[:port]``
    whose host part equals ``request.host``.

    Raises:
        Forbidden: if the Origin header is missing, malformed (no ``//``), or names
            a different host than the one being served.

    NOTE(review): the right-hand side of the origin comparison was lost in this
    copy; ``request.host`` is the reconstruction — confirm. The request-middleware
    registration (decorator) is also not visible here.
    """
    request.ctx.session = session.get(request)
    # CSRF protection
    if request.method == "GET" and request.headers.upgrade != "websocket":
        return  # Ordinary GET requests are fine
    # partition() never raises, unlike split("//", 1)[1], so a missing or
    # malformed Origin header (e.g. "" or "null") becomes a clean 403 instead of a 500.
    origin = request.headers.origin or ""
    _scheme, sep, origin_host = origin.partition("//")
    if not sep or origin_host != request.host:
        raise Forbidden("Invalid origin: Cross-Site requests not permitted")
async def handle_sanic_exception(request, e):
    """Turn an exception into an error response suited to the client.

    - ``SanicException``: its ``status_code`` and ``context`` are used, and the
      message gets a warning-sign prefix. Any other exception is reported as a 500
      with an empty context.
    - Clients whose ``Accept`` header lacks ``text/html`` get a JSON body
      ``{"error": {"code": ..., "message": ..., **context}}``.
    - Browsers are redirected to ``context["redirect"]`` when that key is present.
      The message is flashed in a ``message`` cookie with ``max_age=5``.
    - Otherwise Sanic's default HTML error page is rendered.

    NOTE(review): the exception-handler registration is not visible in this view.
    """
    message = str(e) or "Internal Server Error"
    if not isinstance(e, SanicException):
        context, code = {}, 500
    else:
        context = e.context or {}
        code = e.status_code
        message = f"⚠️ {message}"

    if "text/html" not in request.headers.accept:
        # Non-browsers get JSON errors
        body = {"error": {"code": code, "message": message, **context}}
        return json(body, status=code)

    if "redirect" in context:
        # Redirections flash the error message via cookies
        response = redirect(context["redirect"])
        response.cookies.add_cookie("message", message, max_age=5)
        return response

    # Otherwise use Sanic's default error page
    renderer = errorpages.HTMLRenderer(request, e, debug=app.debug)
    return renderer.full()
async def resphandler(request, response):
    """Trace each outgoing response (request path and response object).

    Uses the application logger at DEBUG level rather than ``print``. The trace
    can then be silenced by log configuration instead of writing to stdout on
    every single response. The response itself is not modified.

    NOTE(review): the response-middleware registration is not visible in this view.
    """
    logger.debug("In response handler %s %s", request.path, response)
def asend(ws, msg):
    """Send ``msg`` on websocket ``ws``; return whatever ``ws.send`` returns.

    ``bytes`` are passed through unchanged. Any other value is encoded to JSON
    with msgspec and sent as a ``str``. Callers ``await`` the returned value.
    """
    if isinstance(msg, bytes):
        payload = msg
    else:
        # msgspec.json.encode produces bytes; decode so a str is sent
        payload = msgspec.json.encode(msg).decode()
    return ws.send(payload)
async def start_fileserver(app, _):
    """Load config, expose the storage directory under ``/files``, start ``fileserver``.

    The static route serves ``config.config.path`` with range requests, streaming
    of large files and directory listings enabled.

    NOTE(review): the server-start listener registration is not visible in this view.
    """
    # Main process may have loaded the config, but this process hasn't
    config.load_config()
    storage_root = config.config.path
    app.static(
        "/files",
        storage_root,
        use_content_range=True,
        stream_large_files=True,
        directory_view=True,
    )
    await fileserver.start()
async def stop_fileserver(app, _):
    """Stop the shared ``fileserver`` (the counterpart of ``start_fileserver``).

    NOTE(review): the server-stop listener registration is not visible in this view.
    """
    await fileserver.stop()
async def wwwroot(request, path=""):
    """Serve a bundled asset from the ``wwwroot`` directory of the ``cista`` package.

    ``path`` is URL-decoded and passed through ``filename.sanitize``. An empty
    path serves ``index.html``. The content type is guessed from the file name,
    falling back to ``application/octet-stream``.

    Raises:
        NotFound: if the resource cannot be read (any ``OSError``, e.g. a missing
            file or a directory). The original error is chained and recorded in ``extra``.

    NOTE(review): containment inside ``wwwroot`` relies on ``filename.sanitize``
    rejecting traversal components — confirm in cista.util.filename.
    """
    name = filename.sanitize(unquote(path)) if path else "index.html"
    try:
        index = files("cista").joinpath("wwwroot", name).read_bytes()
    except OSError as e:
        raise NotFound(
            f"File not found: /{path}",
            extra={"name": name, "exception": repr(e)},
        ) from e
    mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
    return raw(index, content_type=mime)
async def upload(request, ws):
    """Receive file uploads over a websocket, one byte range per round.

    Each round starts with a JSON text frame decoded as ``FileRange``. Binary
    frames follow and are handed to ``fileserver.alink`` starting at ``req.start``
    until ``req.end`` is reached. A completed range is answered with a
    ``StatusMsg`` "ack". A failed round is answered with an ``ErrorMsg`` and
    logged, and the handler then waits for the next range.

    NOTE(review): route registration (decorator) is not visible in this view.
    """
    alink = fileserver.alink
    while True:
        req = None
        try:
            text = await ws.recv()
            if not isinstance(text, str):
                raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
            req = msgspec.json.decode(text, type=FileRange)
            pos = req.start
            # Pre-bind so the error path below never hits an unbound name when the
            # loop body does not run (empty or inverted range).
            data = None
            while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
                # pos advances by alink's return value (presumably bytes written).
                # NOTE(review): `req.name` was missing in this copy ("upload",, ...);
                # restored — confirm against FileServer's upload command signature.
                pos += await alink(("upload", req.name, pos, data, req.size))
            if pos != req.end:
                # Interrupted by a text/empty frame, or pos overshot the range
                d = f"{len(data)} bytes" if isinstance(data, bytes) else data
                raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
            # Report success
            res = StatusMsg(status="ack", req=req)
            await asend(ws, res)
            await ws.drain()
        except Exception as e:
            res = ErrorMsg(error=str(e), req=req)
            await asend(ws, res)
            # Lazy %-formatting. The former logger.exception(repr(res), e) passed
            # `e` as an arg with no placeholder, so logging itself errored.
            logger.exception("%r", res)
async def download(request, ws):
    """Stream requested byte ranges of files to the client over a websocket.

    Each round starts with a JSON text frame decoded as ``FileRange``. The range
    ``[req.start, req.end)`` is read through ``fileserver.alink`` in chunks of
    at most 1 MiB, and each chunk is sent as-is. A completed range is followed
    by a ``StatusMsg`` "ack". A failed round is answered with an ``ErrorMsg``
    and logged, and the handler then waits for the next command.

    NOTE(review): route registration (decorator) is not visible in this view.
    """
    alink = fileserver.alink
    chunk_size = 1 << 20  # upper bound on bytes requested from alink per chunk
    while True:
        req = None
        logger.debug("Waiting for download command")
        try:
            text = await ws.recv()
            if not isinstance(text, str):
                raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
            req = msgspec.json.decode(text, type=FileRange)
            logger.debug("download %s", req)
            pos = req.start
            while pos < req.end:
                end = min(req.end, pos + chunk_size)
                # NOTE(review): `req.name` was missing in this copy ("download",, ...);
                # restored — confirm against FileServer's download command signature.
                data = await alink(("download", req.name, pos, end))
                if not data:
                    # An empty read would never advance pos and the loop would spin
                    # forever (e.g. a file shorter than the requested range).
                    raise ValueError(f"Unexpected end of data at {pos}, expected {req.end - pos} more bytes")
                await asend(ws, data)
                pos += len(data)
            # Report success
            res = StatusMsg(status="ack", req=req)
            await asend(ws, res)
            await ws.drain()
        except Exception as e:
            res = ErrorMsg(error=str(e), req=req)
            await asend(ws, res)
            # Lazy %-formatting. The former logger.exception(repr(res), e) passed
            # `e` as an arg with no placeholder, so logging itself errored.
            logger.exception("%r", res)
async def control(request, websocket):
    """Execute one control command received over a websocket and report the outcome.

    The first message is decoded as a ``ControlBase`` command and invoked as
    ``cmd()`` through ``asyncio.to_thread``, so a blocking command does not stall
    the event loop. Success is answered with a ``StatusMsg`` "ack". Failure
    (bad JSON, or an exception from the command) is answered with an ``ErrorMsg``
    and logged, as the upload/download handlers do.

    NOTE(review): route registration (decorator) is not visible in this view;
    also confirm ``ErrorMsg.req`` accepts a ``ControlBase`` (``StatusMsg.req`` does).
    """
    cmd = None
    try:
        cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
        await asyncio.to_thread(cmd)
    except Exception as e:
        res = ErrorMsg(error=str(e), req=cmd)
        await asend(websocket, res)
        logger.exception("%r", res)
        return
    await asend(websocket, StatusMsg(status="ack", req=cmd))