diff --git a/cista/app.py b/cista/app.py index b2e28a9..7e7c098 100755 --- a/cista/app.py +++ b/cista/app.py @@ -1,21 +1,54 @@ +import asyncio from importlib.resources import files import msgspec from html5tagger import E -from sanic import Sanic, redirect +from sanic import Forbidden, Sanic, SanicException, errorpages from sanic.log import logger -from sanic.response import html +from sanic.response import html, json, redirect from . import config, session, watching from .auth import authbp from .fileio import FileServer -from .protocol import ErrorMsg, FileRange, StatusMsg +from .protocol import ControlBase, ErrorMsg, FileRange, StatusMsg + app = Sanic("cista") fileserver = FileServer() watching.register(app, "/api/watch") app.blueprint(authbp) +@app.on_request +async def use_session(request): + request.ctx.session = session.get(request) + # CSRF protection + if request.method == "GET" and request.headers.upgrade != "websocket": + return # Ordinary GET requests are fine + if request.headers.origin.split("//", 1)[1] != request.host: + raise Forbidden("Invalid origin: Cross-Site requests not permitted") + +@app.exception(Exception) +async def handle_sanic_exception(request, e): + context, code, message = {}, 500, str(e) or "Internal Server Error" + if isinstance(e, SanicException): + context = e.context or {} + code = e.status_code + message = f"⚠️ {message}" + # Non-browsers get JSON errors + if "text/html" not in request.headers.accept: + return json({"error": {"code": code, "message": message, **context}}, status=code) + # Redirections flash the error message via cookies + if "redirect" in context: + res = redirect(context["redirect"]) + res.cookies.add_cookie("message", message, max_age=5) + return res + # Otherwise use Sanic's default error page + return errorpages.HTMLRenderer(request, e, debug=app.debug).full() + +@app.on_response +async def resphandler(request, response): + print("In response handler", request.path, response) + def asend(ws, msg): return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode()) @@ -36,11 +69,11 @@ async def index_page(request): if not s: return redirect("/login") index = files("cista").joinpath("static", "index.html").read_text() - flash = request.cookies.flash + flash = request.cookies.message if flash: - index += str(E.div(flash, id="flash")) + index += str(E.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8")) res = html(index) - res.cookies.delete_cookie("flash") + session.flash(res, None) return res return html(index) @@ -101,3 +134,9 @@ async def download(request, ws): await asend(ws, res) logger.exception(repr(res), e) return + +@app.websocket("/api/control") +async def control(request, websocket): + cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase) + await asyncio.to_thread(cmd) + await asend(websocket, StatusMsg(status="ack", req=cmd)) diff --git a/cista/auth.py b/cista/auth.py index 32faca8..64538b8 100755 --- a/cista/auth.py +++ b/cista/auth.py @@ -6,7 +6,7 @@ from unicodedata import normalize import argon2 import msgspec from html5tagger import Document -from sanic import BadRequest, Blueprint, html, json, redirect +from sanic import BadRequest, Blueprint, Forbidden, html, json, redirect from . import config, session @@ -64,18 +64,17 @@ async def login_page(request): with doc.div(id="login"): with doc.form(method="POST", autocomplete="on"): doc.h1("Login") - doc.input(name="username", placeholder="Username", autocomplete="username").br - doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password").br + doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br + doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br doc.input(type="submit", value="Login") s = session.get(request) if s: name = s["username"] with doc.form(method="POST", action="/logout"): doc.input(type="submit", value=f"Logout {name}") - flash = request.cookies.flash + flash = request.cookies.message if flash: - print("flash", flash) - doc.p(flash) + doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8") res = html(doc) if flash: res.cookies.delete_cookie("flash") @@ -85,9 +84,8 @@ async def login_page(request): @authbp.post("/login") async def login_post(request): - json_format = request.headers.content_type == "application/json" try: - if json_format: + if request.headers.content_type == "application/json": username = request.json["username"] password = request.json["password"] else: @@ -96,36 +94,28 @@ async def login_post(request): if not username or not password: raise KeyError except KeyError: - raise BadRequest("Missing username or password") + raise BadRequest("Missing username or password", context={"redirect": "/login"}) try: user = login(username, password) except ValueError as e: - if json_format: - res = json({ - "status": "error", - "error": str(e), - }) - else: - res = redirect("/login") - res.cookies.add_cookie("flash", str(e), max_age=5) - print("Login error:", res.cookies) - return res + raise Forbidden(str(e), context={"redirect": "/login"}) - if json_format: - res = json({ - "status": "authenticated", - "user": username, - "privileged": user.privileged, - }) - else: + if "text/html" in request.headers.accept: res = redirect("/") - res.cookies.add_cookie("flash", "Logged in", max_age=5) + session.flash(res, "Logged in") + else: + res = json({"data": {"username": username, "privileged": user.privileged}}) session.create(res, username) return res @authbp.post("/logout") async def logout_post(request): - res = redirect("/") + s = request.ctx.session + msg = "Logged out" if s else "Not logged in" + if "text/html" in request.headers.accept: + res = redirect("/login") + res.cookies.add_cookie("flash", msg, max_age=5) + else: + res = json({"message": msg}) session.delete(res) - res.cookies.add_cookie("flash", "Logged out", max_age=5) return res diff --git a/cista/fileio.py b/cista/fileio.py index e225d82..6bedc7e 100755 --- a/cista/fileio.py +++ b/cista/fileio.py @@ -5,7 +5,7 @@ from pathlib import Path, PurePosixPath from pathvalidate import sanitize_filepath -from . import config +from . import config, protocol from .asynclink import AsyncLink from .lrucache import LRUCache @@ -80,6 +80,8 @@ class FileServer: try: for req in slink: with req as (command, *args): + if command == "control": + self.control(*args) if command == "upload": req.set_result(self.upload(*args)) elif command == "download": diff --git a/cista/protocol.py b/cista/protocol.py index f24966d..494fe8e 100755 --- a/cista/protocol.py +++ b/cista/protocol.py @@ -1,9 +1,67 @@ from __future__ import annotations +import shutil -from typing import Dict, Tuple, Union +from sanic import BadRequest +from cista import config +from cista.fileio import sanitize_filename import msgspec +## Control commands + +class ControlBase(msgspec.Struct, tag_field="op", tag=str.lower): + def __call__(self): + raise NotImplementedError + +class MkDir(ControlBase): + path: str + def __call__(self): + path = config.config.path / sanitize_filename(self.path) + path.mkdir(parents=False, exist_ok=False) + +class Rename(ControlBase): + path: str + to: str + def __call__(self): + to = sanitize_filename(self.to) + if "/" in to: + raise BadRequest("Rename 'to' name should only contain filename, not path") + path = config.config.path / sanitize_filename(self.path) + path.rename(path.with_name(to)) + +class Rm(ControlBase): + sel: list[str] + def __call__(self): + root = config.config.path + sel = [root / sanitize_filename(p) for p in self.sel] + for p in sel: + shutil.rmtree(p, ignore_errors=True) + +class Mv(ControlBase): + sel: list[str] + dst: str + def __call__(self): + root = config.config.path + sel = [root / sanitize_filename(p) for p in self.sel] + dst = root / sanitize_filename(self.dst) + if not dst.is_dir(): + raise BadRequest("The destination must be a directory") + for p in sel: + shutil.move(p, dst) + +class Cp(ControlBase): + sel: list[str] + dst: str + def __call__(self): + root = config.config.path + sel = [root / sanitize_filename(p) for p in self.sel] + dst = root / sanitize_filename(self.dst) + if not dst.is_dir(): + raise BadRequest("The destination must be a directory") + for p in sel: + # Note: copies as dst rather than in dst unless name is appended. + shutil.copytree(p, dst / p.name, dirs_exist_ok=True, ignore_dangling_symlinks=True) + ## File uploads and downloads class FileRange(msgspec.Struct): @@ -52,7 +110,7 @@ class DirEntry(msgspec.Struct): if k != "dir" } -DirList = dict[str, Union[FileEntry, DirEntry]] +DirList = dict[str, FileEntry | DirEntry] class UpdateEntry(msgspec.Struct, omit_defaults=True): diff --git a/cista/session.py b/cista/session.py index a310f0e..7d0bc94 100755 --- a/cista/session.py +++ b/cista/session.py @@ -30,3 +30,10 @@ def update(res, s, **kwargs): def delete(res): res.cookies.delete_cookie("s") + + +def flash(res, message: str | None): + if message is None: + res.cookies.delete_cookie("message") + else: + res.cookies.add_cookie("message", message, max_age=5) diff --git a/pyproject.toml b/pyproject.toml index 1b96717..e163766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,11 @@ Homepage = "" [project.scripts] cista = "cista.__main__:main" +[project.optional-dependencies] +dev = [ + "pytest", +] + [tool.hatch.version] source = "vcs" @@ -43,3 +48,13 @@ version_tuple = {version_tuple!r} targets.sdist.include = [ "/cista", ] + +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", + "--verbosity=-1", + "-p no:warnings", +] +testpaths = [ + "tests", +] diff --git a/tests/test_control.py b/tests/test_control.py new file mode 100644 index 0000000..bd767a4 --- /dev/null +++ b/tests/test_control.py @@ -0,0 +1,45 @@ +import pytest +from cista import config +from cista.protocol import MkDir, Rename, Rm, Mv, Cp +from pathlib import Path +import tempfile + +@pytest.fixture +def setup_temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + config.config = config.Config(path=Path(tmpdirname), listen=":0") + yield Path(tmpdirname) + +def test_mkdir(setup_temp_dir): + cmd = MkDir(path="new_folder") + cmd() + assert (setup_temp_dir / "new_folder").is_dir() + +def test_rename(setup_temp_dir): + (setup_temp_dir / "old_name").mkdir() + cmd = Rename(path="old_name", to="new_name") + cmd() + assert not (setup_temp_dir / "old_name").exists() + assert (setup_temp_dir / "new_name").is_dir() + +def test_rm(setup_temp_dir): + (setup_temp_dir / "folder_to_remove").mkdir() + cmd = Rm(sel=["folder_to_remove"]) + cmd() + assert not (setup_temp_dir / "folder_to_remove").exists() + +def test_mv(setup_temp_dir): + (setup_temp_dir / "folder_to_move").mkdir() + (setup_temp_dir / "destination").mkdir() + cmd = Mv(sel=["folder_to_move"], dst="destination") + cmd() + assert not (setup_temp_dir / "folder_to_move").exists() + assert (setup_temp_dir / "destination" / "folder_to_move").is_dir() + +def test_cp(setup_temp_dir): + (setup_temp_dir / "folder_to_copy").mkdir() + (setup_temp_dir / "destination").mkdir() + cmd = Cp(sel=["folder_to_copy"], dst="destination") + cmd() + assert (setup_temp_dir / "folder_to_copy").is_dir() + assert (setup_temp_dir / "destination" / "folder_to_copy").is_dir()