diff --git a/cista/__main__.py b/cista/__main__.py index d0de830..bfff876 100755 --- a/cista/__main__.py +++ b/cista/__main__.py @@ -7,7 +7,7 @@ import cista from cista import app, config, droppy, serve, server80 from cista.util import pwgen -app, server80.app # Needed for Sanic multiprocessing +del app, server80.app # Only import needed, for Sanic multiprocessing doc = f"""Cista {cista.__version__} - A file storage for the web. diff --git a/cista/api.py b/cista/api.py new file mode 100644 index 0000000..9d74738 --- /dev/null +++ b/cista/api.py @@ -0,0 +1,86 @@ +import asyncio +import typing + +import msgspec +from sanic import Blueprint, SanicException + +from cista import watching +from cista.fileio import FileServer +from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg +from cista.util.apphelpers import asend, websocket_wrapper + +bp = Blueprint("api", url_prefix="/api") +fileserver = FileServer() + +@bp.before_server_start +async def start_fileserver(app, _): + await fileserver.start() + +@bp.after_server_stop +async def stop_fileserver(app, _): + await fileserver.stop() + +@bp.websocket('upload') +@websocket_wrapper +async def upload(req, ws): + alink = fileserver.alink + while True: + req = None + text = await ws.recv() + if not isinstance(text, str): + raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") + req = msgspec.json.decode(text, type=FileRange) + pos = req.start + data = None + while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes): + sentsize = await alink(("upload", req.name, pos, data, req.size)) + pos += typing.cast(int, sentsize) + if pos != req.end: + d = f"{len(data)} bytes" if isinstance(data, bytes) else data + raise ValueError(f"Expected {req.end - pos} more bytes, got {d}") + # Report success + res = StatusMsg(status="ack", req=req) + await asend(ws, res) + #await ws.drain() + +@bp.websocket('download') +@websocket_wrapper +async def download(req, ws): + alink = fileserver.alink + while True: + req = None + text = await ws.recv() + if not isinstance(text, str): + raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") + req = msgspec.json.decode(text, type=FileRange) + pos = req.start + while pos < req.end: + end = min(req.end, pos + (1<<20)) + data = typing.cast(bytes, await alink(("download", req.name, pos, end))) + await asend(ws, data) + pos += len(data) + # Report success + res = StatusMsg(status="ack", req=req) + await asend(ws, res) + #await ws.drain() + +@bp.websocket("control") +@websocket_wrapper +async def control(req, ws): + cmd = msgspec.json.decode(await ws.recv(), type=ControlBase) + await asyncio.to_thread(cmd) + await asend(ws, StatusMsg(status="ack", req=cmd)) + +@bp.websocket("watch") +@websocket_wrapper +async def watch(req, ws): + try: + with watching.tree_lock: + q = watching.pubsub[ws] = asyncio.Queue() + # Init with full tree + await ws.send(watching.refresh()) + # Send updates + while True: + await ws.send(await q.get()) + finally: + del watching.pubsub[ws] diff --git a/cista/app.py b/cista/app.py index def31cc..5f0e010 100755 --- a/cista/app.py +++ b/cista/app.py @@ -1,72 +1,56 @@ -import asyncio -from importlib.resources import files - -import msgspec -from html5tagger import E -from sanic import Forbidden, Sanic, SanicException, errorpages, raw -from sanic.log import logger -from sanic.response import html, json, redirect import mimetypes - -from cista import config, session, watching -from cista.util import filename -from cista.auth import authbp -from cista.fileio import FileServer -from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg +from importlib.resources import files from urllib.parse import unquote -app = Sanic("cista") -fileserver = FileServer() -watching.register(app, "/api/watch") -app.blueprint(authbp) +from html5tagger import E +from sanic import Blueprint, Sanic, raw +from sanic.exceptions import Forbidden, NotFound -@app.on_request -async def use_session(request): - request.ctx.session = session.get(request) - # CSRF protection - if request.method == "GET" and request.headers.upgrade != "websocket": - return # Ordinary GET requests are fine - if request.headers.origin.split("//", 1)[1] != request.host: - raise Forbidden("Invalid origin: Cross-Site requests not permitted") +from cista import auth, config, session, watching +from cista.api import bp +from cista.util import filename +from cista.util.apphelpers import handle_sanic_exception -@app.exception(Exception) -async def handle_sanic_exception(request, e): - context, code, message = {}, 500, str(e) or "Internal Server Error" - if isinstance(e, SanicException): - context = e.context or {} - code = e.status_code - message = f"⚠️ {message}" - # Non-browsers get JSON errors - if "text/html" not in request.headers.accept: - return json({"error": {"code": code, "message": message, **context}}, status=code) - # Redirections flash the error message via cookies - if "redirect" in context: - res = redirect(context["redirect"]) - res.cookies.add_cookie("message", message, max_age=5) - return res - # Otherwise use Sanic's default error page - return errorpages.HTMLRenderer(request, e, debug=app.debug).full() - -@app.on_response -async def resphandler(request, response): - print("In response handler", request.path, response) - -def asend(ws, msg): - return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode()) +app = Sanic("cista", strict_slashes=True) +app.blueprint(auth.bp) +app.blueprint(bp) +app.exception(Exception)(handle_sanic_exception) @app.before_server_start -async def start_fileserver(app, _): - config.load_config() # Main process may have loaded it but we haven't - app.static("/files", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True) - - await fileserver.start() +async def main_start(app, loop): + config.load_config() + await watching.start(app, loop) @app.after_server_stop -async def stop_fileserver(app, _): - await fileserver.stop() +async def main_stop(app, loop): + await watching.stop(app, loop) -@app.get("/") -async def wwwroot(request, path=""): +@app.on_request +async def use_session(req): + req.ctx.session = session.get(req) + try: + req.ctx.user = config.config.users[req.ctx.session["user"]] # type: ignore + except (AttributeError, KeyError, TypeError): + req.ctx.user = None + # CSRF protection + if req.method == "GET" and req.headers.upgrade != "websocket": + return # Ordinary GET requests are fine + # Check that origin matches host, for browsers which should all send Origin. + # Curl doesn't send any Origin header, so we allow it anyway. + origin = req.headers.origin + if origin and origin.split("//", 1)[1] != req.host: + raise Forbidden("Invalid origin: Cross-Site requests not permitted") + +@app.before_server_start +def http_fileserver(app, _): + bp = Blueprint("fileserver") + bp.on_request(auth.verify) + bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True) + app.blueprint(bp) + +@app.get("/", static=True) +async def wwwroot(req, path=""): + """Frontend files only""" name = filename.sanitize(unquote(path)) if path else "index.html" try: index = files("cista").joinpath("wwwroot", name).read_bytes() @@ -74,67 +58,3 @@ async def wwwroot(request, path=""): raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)}) mime = mimetypes.guess_type(name)[0] or "application/octet-stream" return raw(index, content_type=mime) - -@app.websocket('/api/upload') -async def upload(request, ws): - alink = fileserver.alink - url = request.url_for("upload") - while True: - req = None - try: - text = await ws.recv() - if not isinstance(text, str): - raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") - req = msgspec.json.decode(text, type=FileRange) - pos = req.start - while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes): - pos += await alink(("upload", req.name, pos, data, req.size)) - if pos != req.end: - d = f"{len(data)} bytes" if isinstance(data, bytes) else data - raise ValueError(f"Expected {req.end - pos} more bytes, got {d}") - # Report success - res = StatusMsg(status="ack", req=req) - await asend(ws, res) - await ws.drain() - except Exception as e: - res = ErrorMsg(error=str(e), req=req) - await asend(ws, res) - logger.exception(repr(res), e) - return - -@app.websocket('/api/download') -async def download(request, ws): - alink = fileserver.alink - while True: - req = None - try: - print("Waiting for download command") - text = await ws.recv() - if not isinstance(text, str): - raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") - req = msgspec.json.decode(text, type=FileRange) - print("download", req) - pos = req.start - while pos < req.end: - end = min(req.end, pos + (1<<20)) - data = await alink(("download", req.name, pos, end)) - await asend(ws, data) - pos += len(data) - # Report success - res = StatusMsg(status="ack", req=req) - await asend(ws, res) - print(ws, dir(ws)) - await ws.drain() - print(res) - - except Exception as e: - res = ErrorMsg(error=str(e), req=req) - await asend(ws, res) - logger.exception(repr(res), e) - return - -@app.websocket("/api/control") -async def control(request, websocket): - cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase) - await asyncio.to_thread(cmd) - await asend(websocket, StatusMsg(status="ack", req=cmd)) diff --git a/cista/auth.py b/cista/auth.py index d17de74..570f373 100755 --- a/cista/auth.py +++ b/cista/auth.py @@ -6,7 +6,8 @@ from unicodedata import normalize import argon2 import msgspec from html5tagger import Document -from sanic import BadRequest, Blueprint, Forbidden, html, json, redirect +from sanic import Blueprint, html, json, redirect +from sanic.exceptions import BadRequest, Forbidden, Unauthorized from cista import config, session @@ -56,9 +57,18 @@ class LoginResponse(msgspec.Struct): privileged: bool = False error: str = "" -authbp = Blueprint("auth") +def verify(request, privileged=False): + """Raise Unauthorized or Forbidden if the request is not authorized""" + if privileged: + if request.ctx.user: + if request.ctx.user.privileged: return + raise Forbidden("Access Forbidden: Only for privileged users") + elif config.config.public or request.ctx.user: return + raise Unauthorized("Login required", "cookie", context={"redirect": "/login"}) -@authbp.get("/login") +bp = Blueprint("auth") + +@bp.get("/login") async def login_page(request): doc = Document("Cista Login") with doc.div(id="login"): @@ -82,7 +92,7 @@ async def login_page(request): session.delete(res) return res -@authbp.post("/login") +@bp.post("/login") async def login_post(request): try: if request.headers.content_type == "application/json": @@ -108,7 +118,7 @@ async def login_post(request): session.create(res, username) return res -@authbp.post("/logout") +@bp.post("/logout") async def logout_post(request): s = request.ctx.session msg = "Logged out" if s else "Not logged in" diff --git a/cista/fileio.py b/cista/fileio.py index b525e00..9c10c82 100755 --- a/cista/fileio.py +++ b/cista/fileio.py @@ -30,6 +30,7 @@ class File: self.writable = True def write(self, pos, buffer, *, file_size=None): + assert self.fd is not None if not self.writable: # Create/open file self.open_rw() @@ -39,6 +40,7 @@ class File: os.write(self.fd, buffer) def __getitem__(self, slice): + assert self.fd is not None if self.fd is None: self.open_ro() os.lseek(self.fd, slice.start, os.SEEK_SET) @@ -71,8 +73,6 @@ class FileServer: try: for req in slink: with req as (command, *args): - if command == "control": - self.control(*args) if command == "upload": req.set_result(self.upload(*args)) elif command == "download": diff --git a/cista/index.html b/cista/index.html deleted file mode 100755 index db336cb..0000000 --- a/cista/index.html +++ /dev/null @@ -1,241 +0,0 @@ - -Storage - -
-

Quick file upload

-

Uses parallel WebSocket connections for increased bandwidth /api/upload

- - -
- -
-

Files

-
    -
    - - diff --git a/cista/protocol.py b/cista/protocol.py index 6fbf0e3..1964e06 100755 --- a/cista/protocol.py +++ b/cista/protocol.py @@ -1,6 +1,7 @@ from __future__ import annotations import shutil +from typing import Any import msgspec from sanic import BadRequest @@ -71,14 +72,12 @@ class FileRange(msgspec.Struct): start: int end: int -class ErrorMsg(msgspec.Struct): - error: str - req: FileRange - class StatusMsg(msgspec.Struct): status: str req: FileRange +class ErrorMsg(msgspec.Struct): + error: dict[str, Any] ## Directory listings diff --git a/cista/serve.py b/cista/serve.py index 35f978a..7b8ecc9 100755 --- a/cista/serve.py +++ b/cista/serve.py @@ -4,7 +4,7 @@ from pathlib import Path from sanic import Sanic -from cista import config +from cista import config, server80 def run(dev=False): @@ -15,11 +15,10 @@ def run(dev=False): os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1" if opts.get("ssl"): # Run plain HTTP redirect/acme server on port 80 - from . import server80 server80.app.prepare(port=80, motd=False) domain = opts["host"] - opts["ssl"] = str(config.conffile.parent / domain) - app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, access_log=True) + opts["ssl"] = str(config.conffile.parent / domain) # type: ignore + app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, access_log=True) # type: ignore Sanic.serve() def parse_listen(listen): diff --git a/cista/session.py b/cista/session.py index ea030d1..b718290 100755 --- a/cista/session.py +++ b/cista/session.py @@ -26,7 +26,8 @@ def create(res, username, **kwargs): def update(res, s, **kwargs): s.update(kwargs) s = jwt.encode(s, session_secret()) - res.cookies.add_cookie("s", s, httponly=True, max_age=max(1, s["exp"] - int(time()))) + max_age = max(1, s["exp"] - int(time())) # type: ignore + res.cookies.add_cookie("s", s, httponly=True, max_age=max_age) def delete(res): res.cookies.delete_cookie("s") diff --git a/cista/util/apphelpers.py b/cista/util/apphelpers.py new file mode 100644 index 0000000..f5d790d --- /dev/null +++ b/cista/util/apphelpers.py @@ -0,0 +1,58 @@ +from functools import wraps + +import msgspec +from sanic import errorpages +from sanic.exceptions import SanicException +from sanic.log import logger +from sanic.response import raw, redirect + +from cista import auth +from cista.protocol import ErrorMsg + + +def asend(ws, msg): + """Send JSON message or bytes to a websocket""" + return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode()) + +def jres(data, **kwargs): + """JSON Sanic response, using msgspec encoding""" + return raw(msgspec.json.encode(data), content_type="application/json", **kwargs) + +async def handle_sanic_exception(request, e): + logger.exception(e) + context, code = {}, 500 + message = str(e) + if isinstance(e, SanicException): + context = e.context or {} + code = e.status_code + if not message or not request.app.debug and code == 500: + message = "Internal Server Error" + message = f"⚠️ {message}" if code < 500 else f"🛑 {message}" + # Non-browsers get JSON errors + if "text/html" not in request.headers.accept: + return jres(ErrorMsg({"code": code, "message": message, **context}), status=code) + # Redirections flash the error message via cookies + if "redirect" in context: + res = redirect(context["redirect"]) + res.cookies.add_cookie("message", message, max_age=5) + return res + # Otherwise use Sanic's default error page + return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full() + +def websocket_wrapper(handler): + """Decorator for websocket handlers that catches exceptions and sends them back to the client""" + @wraps(handler) + async def wrapper(request, ws, *args, **kwargs): + try: + auth.verify(request) + await handler(request, ws, *args, **kwargs) + except Exception as e: + logger.exception(e) + context, code, message = {}, 500, str(e) or "Internal Server Error" + if isinstance(e, SanicException): + context = e.context or {} + code = e.status_code + message = f"⚠️ {message}" if code < 500 else f"🛑 {message}" + await asend(ws, ErrorMsg({"code": code, "message": message, **context})) + raise + return wrapper diff --git a/cista/util/filename.py b/cista/util/filename.py index 9f863a4..c464b2d 100644 --- a/cista/util/filename.py +++ b/cista/util/filename.py @@ -1,7 +1,9 @@ -from pathvalidate import sanitize_filepath import unicodedata from pathlib import PurePosixPath +from pathvalidate import sanitize_filepath + + def sanitize(filename: str) -> str: filename = unicodedata.normalize("NFC", filename) # UNIX filenames can contain backslashes but for compatibility we replace them with dashes diff --git a/cista/watching.py b/cista/watching.py index 02ad054..6b76b09 100755 --- a/cista/watching.py +++ b/cista/watching.py @@ -8,6 +8,7 @@ import msgspec from cista import config from cista.protocol import DirEntry, FileEntry, UpdateEntry +from cista.util.apphelpers import websocket_wrapper pubsub = {} tree = {"": None} @@ -132,30 +133,14 @@ async def broadcast(msg): for queue in pubsub.values(): await queue.put_nowait(msg) -def register(app, url): - @app.before_server_start - async def start_watcher(app, loop): - global rootpath - config.load_config() - rootpath = config.config.path - app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop]) - app.ctx.watcher.start() +async def start(app, loop): + global rootpath + config.load_config() + rootpath = config.config.path + app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop]) + app.ctx.watcher.start() - @app.after_server_stop - async def stop_watcher(app, _): - global quit - quit = True - app.ctx.watcher.join() - - @app.websocket(url) - async def watch(request, ws): - try: - with tree_lock: - q = pubsub[ws] = asyncio.Queue() - # Init with full tree - await ws.send(refresh()) - # Send updates - while True: - await ws.send(await q.get()) - finally: - del pubsub[ws] +async def stop(app, loop): + global quit + quit = True + app.ctx.watcher.join() diff --git a/pyproject.toml b/pyproject.toml index 3858e56..bcfd2e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,3 +62,8 @@ addopts = [ testpaths = [ "tests", ] + +[tool.isort] +#src_paths = ["cista", "tests"] +line_length = 120 +multi_line_output = 5 diff --git a/tests/test_control.py b/tests/test_control.py index bd767a4..858f52b 100644 --- a/tests/test_control.py +++ b/tests/test_control.py @@ -1,8 +1,11 @@ -import pytest -from cista import config -from cista.protocol import MkDir, Rename, Rm, Mv, Cp -from pathlib import Path import tempfile +from pathlib import Path + +import pytest + +from cista import config +from cista.protocol import Cp, MkDir, Mv, Rename, Rm + @pytest.fixture def setup_temp_dir():