Cleanup, bugfixes. Added access control on files and API.
This commit is contained in:
166
cista/app.py
166
cista/app.py
@@ -1,72 +1,56 @@
|
||||
import asyncio
|
||||
from importlib.resources import files
|
||||
|
||||
import msgspec
|
||||
from html5tagger import E
|
||||
from sanic import Forbidden, Sanic, SanicException, errorpages, raw
|
||||
from sanic.log import logger
|
||||
from sanic.response import html, json, redirect
|
||||
import mimetypes
|
||||
|
||||
from cista import config, session, watching
|
||||
from cista.util import filename
|
||||
from cista.auth import authbp
|
||||
from cista.fileio import FileServer
|
||||
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
|
||||
from importlib.resources import files
|
||||
from urllib.parse import unquote
|
||||
|
||||
app = Sanic("cista")
|
||||
fileserver = FileServer()
|
||||
watching.register(app, "/api/watch")
|
||||
app.blueprint(authbp)
|
||||
from html5tagger import E
|
||||
from sanic import Blueprint, Sanic, raw
|
||||
from sanic.exceptions import Forbidden, NotFound
|
||||
|
||||
@app.on_request
|
||||
async def use_session(request):
|
||||
request.ctx.session = session.get(request)
|
||||
# CSRF protection
|
||||
if request.method == "GET" and request.headers.upgrade != "websocket":
|
||||
return # Ordinary GET requests are fine
|
||||
if request.headers.origin.split("//", 1)[1] != request.host:
|
||||
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
|
||||
from cista import auth, config, session, watching
|
||||
from cista.api import bp
|
||||
from cista.util import filename
|
||||
from cista.util.apphelpers import handle_sanic_exception
|
||||
|
||||
@app.exception(Exception)
|
||||
async def handle_sanic_exception(request, e):
|
||||
context, code, message = {}, 500, str(e) or "Internal Server Error"
|
||||
if isinstance(e, SanicException):
|
||||
context = e.context or {}
|
||||
code = e.status_code
|
||||
message = f"⚠️ {message}"
|
||||
# Non-browsers get JSON errors
|
||||
if "text/html" not in request.headers.accept:
|
||||
return json({"error": {"code": code, "message": message, **context}}, status=code)
|
||||
# Redirections flash the error message via cookies
|
||||
if "redirect" in context:
|
||||
res = redirect(context["redirect"])
|
||||
res.cookies.add_cookie("message", message, max_age=5)
|
||||
return res
|
||||
# Otherwise use Sanic's default error page
|
||||
return errorpages.HTMLRenderer(request, e, debug=app.debug).full()
|
||||
|
||||
@app.on_response
|
||||
async def resphandler(request, response):
|
||||
print("In response handler", request.path, response)
|
||||
|
||||
def asend(ws, msg):
|
||||
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
|
||||
app = Sanic("cista", strict_slashes=True)
|
||||
app.blueprint(auth.bp)
|
||||
app.blueprint(bp)
|
||||
app.exception(Exception)(handle_sanic_exception)
|
||||
|
||||
@app.before_server_start
|
||||
async def start_fileserver(app, _):
|
||||
config.load_config() # Main process may have loaded it but we haven't
|
||||
app.static("/files", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
|
||||
|
||||
await fileserver.start()
|
||||
async def main_start(app, loop):
|
||||
config.load_config()
|
||||
await watching.start(app, loop)
|
||||
|
||||
@app.after_server_stop
|
||||
async def stop_fileserver(app, _):
|
||||
await fileserver.stop()
|
||||
async def main_stop(app, loop):
|
||||
await watching.stop(app, loop)
|
||||
|
||||
@app.get("/<path:path>")
|
||||
async def wwwroot(request, path=""):
|
||||
@app.on_request
|
||||
async def use_session(req):
|
||||
req.ctx.session = session.get(req)
|
||||
try:
|
||||
req.ctx.user = config.config.users[req.ctx.session["user"]] # type: ignore
|
||||
except (AttributeError, KeyError, TypeError):
|
||||
req.ctx.user = None
|
||||
# CSRF protection
|
||||
if req.method == "GET" and req.headers.upgrade != "websocket":
|
||||
return # Ordinary GET requests are fine
|
||||
# Check that origin matches host, for browsers which should all send Origin.
|
||||
# Curl doesn't send any Origin header, so we allow it anyway.
|
||||
origin = req.headers.origin
|
||||
if origin and origin.split("//", 1)[1] != req.host:
|
||||
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
|
||||
|
||||
@app.before_server_start
|
||||
def http_fileserver(app, _):
|
||||
bp = Blueprint("fileserver")
|
||||
bp.on_request(auth.verify)
|
||||
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
|
||||
app.blueprint(bp)
|
||||
|
||||
@app.get("/<path:path>", static=True)
|
||||
async def wwwroot(req, path=""):
|
||||
"""Frontend files only"""
|
||||
name = filename.sanitize(unquote(path)) if path else "index.html"
|
||||
try:
|
||||
index = files("cista").joinpath("wwwroot", name).read_bytes()
|
||||
@@ -74,67 +58,3 @@ async def wwwroot(request, path=""):
|
||||
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
|
||||
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
|
||||
return raw(index, content_type=mime)
|
||||
|
||||
@app.websocket('/api/upload')
|
||||
async def upload(request, ws):
|
||||
alink = fileserver.alink
|
||||
url = request.url_for("upload")
|
||||
while True:
|
||||
req = None
|
||||
try:
|
||||
text = await ws.recv()
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
||||
req = msgspec.json.decode(text, type=FileRange)
|
||||
pos = req.start
|
||||
while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
|
||||
pos += await alink(("upload", req.name, pos, data, req.size))
|
||||
if pos != req.end:
|
||||
d = f"{len(data)} bytes" if isinstance(data, bytes) else data
|
||||
raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
|
||||
# Report success
|
||||
res = StatusMsg(status="ack", req=req)
|
||||
await asend(ws, res)
|
||||
await ws.drain()
|
||||
except Exception as e:
|
||||
res = ErrorMsg(error=str(e), req=req)
|
||||
await asend(ws, res)
|
||||
logger.exception(repr(res), e)
|
||||
return
|
||||
|
||||
@app.websocket('/api/download')
|
||||
async def download(request, ws):
|
||||
alink = fileserver.alink
|
||||
while True:
|
||||
req = None
|
||||
try:
|
||||
print("Waiting for download command")
|
||||
text = await ws.recv()
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
||||
req = msgspec.json.decode(text, type=FileRange)
|
||||
print("download", req)
|
||||
pos = req.start
|
||||
while pos < req.end:
|
||||
end = min(req.end, pos + (1<<20))
|
||||
data = await alink(("download", req.name, pos, end))
|
||||
await asend(ws, data)
|
||||
pos += len(data)
|
||||
# Report success
|
||||
res = StatusMsg(status="ack", req=req)
|
||||
await asend(ws, res)
|
||||
print(ws, dir(ws))
|
||||
await ws.drain()
|
||||
print(res)
|
||||
|
||||
except Exception as e:
|
||||
res = ErrorMsg(error=str(e), req=req)
|
||||
await asend(ws, res)
|
||||
logger.exception(repr(res), e)
|
||||
return
|
||||
|
||||
@app.websocket("/api/control")
|
||||
async def control(request, websocket):
|
||||
cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
|
||||
await asyncio.to_thread(cmd)
|
||||
await asend(websocket, StatusMsg(status="ack", req=cmd))
|
||||
|
||||
Reference in New Issue
Block a user