Cleanup, bugfixes. Added access control on files and API.

This commit is contained in:
Leo Vasanko
2023-10-23 04:51:39 +03:00
committed by Leo Vasanko
parent bd61d7451e
commit 4852212347
14 changed files with 239 additions and 412 deletions

View File

@@ -1,72 +1,56 @@
import asyncio
from importlib.resources import files
import msgspec
from html5tagger import E
from sanic import Forbidden, Sanic, SanicException, errorpages, raw
from sanic.log import logger
from sanic.response import html, json, redirect
import mimetypes
from cista import config, session, watching
from cista.util import filename
from cista.auth import authbp
from cista.fileio import FileServer
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
from importlib.resources import files
from urllib.parse import unquote
app = Sanic("cista")
fileserver = FileServer()
watching.register(app, "/api/watch")
app.blueprint(authbp)
from html5tagger import E
from sanic import Blueprint, Sanic, raw
from sanic.exceptions import Forbidden, NotFound
@app.on_request
async def use_session(request):
request.ctx.session = session.get(request)
# CSRF protection
if request.method == "GET" and request.headers.upgrade != "websocket":
return # Ordinary GET requests are fine
if request.headers.origin.split("//", 1)[1] != request.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
from cista import auth, config, session, watching
from cista.api import bp
from cista.util import filename
from cista.util.apphelpers import handle_sanic_exception
@app.exception(Exception)
async def handle_sanic_exception(request, e):
context, code, message = {}, 500, str(e) or "Internal Server Error"
if isinstance(e, SanicException):
context = e.context or {}
code = e.status_code
message = f"⚠️ {message}"
# Non-browsers get JSON errors
if "text/html" not in request.headers.accept:
return json({"error": {"code": code, "message": message, **context}}, status=code)
# Redirections flash the error message via cookies
if "redirect" in context:
res = redirect(context["redirect"])
res.cookies.add_cookie("message", message, max_age=5)
return res
# Otherwise use Sanic's default error page
return errorpages.HTMLRenderer(request, e, debug=app.debug).full()
@app.on_response
async def resphandler(request, response):
print("In response handler", request.path, response)
def asend(ws, msg):
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
app = Sanic("cista", strict_slashes=True)
app.blueprint(auth.bp)
app.blueprint(bp)
app.exception(Exception)(handle_sanic_exception)
@app.before_server_start
async def start_fileserver(app, _):
config.load_config() # Main process may have loaded it but we haven't
app.static("/files", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
await fileserver.start()
async def main_start(app, loop):
config.load_config()
await watching.start(app, loop)
@app.after_server_stop
async def stop_fileserver(app, _):
await fileserver.stop()
async def main_stop(app, loop):
await watching.stop(app, loop)
@app.get("/<path:path>")
async def wwwroot(request, path=""):
@app.on_request
async def use_session(req):
req.ctx.session = session.get(req)
try:
req.ctx.user = config.config.users[req.ctx.session["user"]] # type: ignore
except (AttributeError, KeyError, TypeError):
req.ctx.user = None
# CSRF protection
if req.method == "GET" and req.headers.upgrade != "websocket":
return # Ordinary GET requests are fine
# Check that origin matches host, for browsers which should all send Origin.
# Curl doesn't send any Origin header, so we allow it anyway.
origin = req.headers.origin
if origin and origin.split("//", 1)[1] != req.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
@app.before_server_start
def http_fileserver(app, _):
bp = Blueprint("fileserver")
bp.on_request(auth.verify)
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
app.blueprint(bp)
@app.get("/<path:path>", static=True)
async def wwwroot(req, path=""):
"""Frontend files only"""
name = filename.sanitize(unquote(path)) if path else "index.html"
try:
index = files("cista").joinpath("wwwroot", name).read_bytes()
@@ -74,67 +58,3 @@ async def wwwroot(request, path=""):
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
return raw(index, content_type=mime)
@app.websocket('/api/upload')
async def upload(request, ws):
alink = fileserver.alink
url = request.url_for("upload")
while True:
req = None
try:
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
req = msgspec.json.decode(text, type=FileRange)
pos = req.start
while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
pos += await alink(("upload", req.name, pos, data, req.size))
if pos != req.end:
d = f"{len(data)} bytes" if isinstance(data, bytes) else data
raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
# Report success
res = StatusMsg(status="ack", req=req)
await asend(ws, res)
await ws.drain()
except Exception as e:
res = ErrorMsg(error=str(e), req=req)
await asend(ws, res)
logger.exception(repr(res), e)
return
@app.websocket('/api/download')
async def download(request, ws):
alink = fileserver.alink
while True:
req = None
try:
print("Waiting for download command")
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
req = msgspec.json.decode(text, type=FileRange)
print("download", req)
pos = req.start
while pos < req.end:
end = min(req.end, pos + (1<<20))
data = await alink(("download", req.name, pos, end))
await asend(ws, data)
pos += len(data)
# Report success
res = StatusMsg(status="ack", req=req)
await asend(ws, res)
print(ws, dir(ws))
await ws.drain()
print(res)
except Exception as e:
res = ErrorMsg(error=str(e), req=req)
await asend(ws, res)
logger.exception(repr(res), e)
return
@app.websocket("/api/control")
async def control(request, websocket):
cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
await asyncio.to_thread(cmd)
await asend(websocket, StatusMsg(status="ack", req=cmd))