Implemented control commands and tests. Rewritten error and session/flash handling.

This commit is contained in:
Leo Vasanko 2023-10-21 04:44:43 +03:00 committed by Leo Vasanko
parent 9939cb33fa
commit e90174a09d
7 changed files with 194 additions and 38 deletions

View File

@ -1,21 +1,54 @@
import asyncio
from importlib.resources import files from importlib.resources import files
import msgspec import msgspec
from html5tagger import E from html5tagger import E
from sanic import Sanic, redirect from sanic import Forbidden, Sanic, SanicException, errorpages
from sanic.log import logger from sanic.log import logger
from sanic.response import html from sanic.response import html, json, redirect
from . import config, session, watching from . import config, session, watching
from .auth import authbp from .auth import authbp
from .fileio import FileServer from .fileio import FileServer
from .protocol import ErrorMsg, FileRange, StatusMsg from .protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
app = Sanic("cista") app = Sanic("cista")
fileserver = FileServer() fileserver = FileServer()
watching.register(app, "/api/watch") watching.register(app, "/api/watch")
app.blueprint(authbp) app.blueprint(authbp)
@app.on_request
async def use_session(request):
request.ctx.session = session.get(request)
# CSRF protection
if request.method == "GET" and request.headers.upgrade != "websocket":
return # Ordinary GET requests are fine
if request.headers.origin.split("//", 1)[1] != request.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
@app.exception(Exception)
async def handle_sanic_exception(request, e):
context, code, message = {}, 500, str(e) or "Internal Server Error"
if isinstance(e, SanicException):
context = e.context or {}
code = e.status_code
message = f"⚠️ {message}"
# Non-browsers get JSON errors
if "text/html" not in request.headers.accept:
return json({"error": {"code": code, "message": message, **context}}, status=code)
# Redirections flash the error message via cookies
if "redirect" in context:
res = redirect(context["redirect"])
res.cookies.add_cookie("message", message, max_age=5)
return res
# Otherwise use Sanic's default error page
return errorpages.HTMLRenderer(request, e, debug=app.debug).full()
@app.on_response
async def resphandler(request, response):
print("In response handler", request.path, response)
def asend(ws, msg): def asend(ws, msg):
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode()) return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
@ -36,11 +69,11 @@ async def index_page(request):
if not s: if not s:
return redirect("/login") return redirect("/login")
index = files("cista").joinpath("static", "index.html").read_text() index = files("cista").joinpath("static", "index.html").read_text()
flash = request.cookies.flash flash = request.cookies.message
if flash: if flash:
index += str(E.div(flash, id="flash")) index += str(E.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8"))
res = html(index) res = html(index)
res.cookies.delete_cookie("flash") session.flash(res, None)
return res return res
return html(index) return html(index)
@ -101,3 +134,9 @@ async def download(request, ws):
await asend(ws, res) await asend(ws, res)
logger.exception(repr(res), e) logger.exception(repr(res), e)
return return
@app.websocket("/api/control")
async def control(request, websocket):
cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
await asyncio.to_thread(cmd)
await asend(websocket, StatusMsg(status="ack", req=cmd))

View File

@ -6,7 +6,7 @@ from unicodedata import normalize
import argon2 import argon2
import msgspec import msgspec
from html5tagger import Document from html5tagger import Document
from sanic import BadRequest, Blueprint, html, json, redirect from sanic import BadRequest, Blueprint, Forbidden, html, json, redirect
from . import config, session from . import config, session
@ -64,18 +64,17 @@ async def login_page(request):
with doc.div(id="login"): with doc.div(id="login"):
with doc.form(method="POST", autocomplete="on"): with doc.form(method="POST", autocomplete="on"):
doc.h1("Login") doc.h1("Login")
doc.input(name="username", placeholder="Username", autocomplete="username").br doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br
doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password").br doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br
doc.input(type="submit", value="Login") doc.input(type="submit", value="Login")
s = session.get(request) s = session.get(request)
if s: if s:
name = s["username"] name = s["username"]
with doc.form(method="POST", action="/logout"): with doc.form(method="POST", action="/logout"):
doc.input(type="submit", value=f"Logout {name}") doc.input(type="submit", value=f"Logout {name}")
flash = request.cookies.flash flash = request.cookies.message
if flash: if flash:
print("flash", flash) doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8")
doc.p(flash)
res = html(doc) res = html(doc)
if flash: if flash:
res.cookies.delete_cookie("flash") res.cookies.delete_cookie("flash")
@ -85,9 +84,8 @@ async def login_page(request):
@authbp.post("/login") @authbp.post("/login")
async def login_post(request): async def login_post(request):
json_format = request.headers.content_type == "application/json"
try: try:
if json_format: if request.headers.content_type == "application/json":
username = request.json["username"] username = request.json["username"]
password = request.json["password"] password = request.json["password"]
else: else:
@ -96,36 +94,28 @@ async def login_post(request):
if not username or not password: if not username or not password:
raise KeyError raise KeyError
except KeyError: except KeyError:
raise BadRequest("Missing username or password") raise BadRequest("Missing username or password", context={"redirect": "/login"})
try: try:
user = login(username, password) user = login(username, password)
except ValueError as e: except ValueError as e:
if json_format: raise Forbidden(str(e), context={"redirect": "/login"})
res = json({
"status": "error",
"error": str(e),
})
else:
res = redirect("/login")
res.cookies.add_cookie("flash", str(e), max_age=5)
print("Login error:", res.cookies)
return res
if json_format: if "text/html" in request.headers.accept:
res = json({
"status": "authenticated",
"user": username,
"privileged": user.privileged,
})
else:
res = redirect("/") res = redirect("/")
res.cookies.add_cookie("flash", "Logged in", max_age=5) session.flash(res, "Logged in")
else:
res = json({"data": {"username": username, "privileged": user.privileged}})
session.create(res, username) session.create(res, username)
return res return res
@authbp.post("/logout") @authbp.post("/logout")
async def logout_post(request): async def logout_post(request):
res = redirect("/") s = request.ctx.session
msg = "Logged out" if s else "Not logged in"
if "text/html" in request.headers.accept:
res = redirect("/login")
res.cookies.add_cookie("flash", msg, max_age=5)
else:
res = json({"message": msg})
session.delete(res) session.delete(res)
res.cookies.add_cookie("flash", "Logged out", max_age=5)
return res return res

View File

@ -5,7 +5,7 @@ from pathlib import Path, PurePosixPath
from pathvalidate import sanitize_filepath from pathvalidate import sanitize_filepath
from . import config from . import config, protocol
from .asynclink import AsyncLink from .asynclink import AsyncLink
from .lrucache import LRUCache from .lrucache import LRUCache
@ -80,6 +80,8 @@ class FileServer:
try: try:
for req in slink: for req in slink:
with req as (command, *args): with req as (command, *args):
if command == "control":
self.control(*args)
if command == "upload": if command == "upload":
req.set_result(self.upload(*args)) req.set_result(self.upload(*args))
elif command == "download": elif command == "download":

View File

@ -1,9 +1,67 @@
from __future__ import annotations from __future__ import annotations
import shutil
from typing import Dict, Tuple, Union from sanic import BadRequest
from cista import config
from cista.fileio import sanitize_filename
import msgspec import msgspec
## Control commands
class ControlBase(msgspec.Struct, tag_field="op", tag=str.lower):
def __call__(self):
raise NotImplementedError
class MkDir(ControlBase):
path: str
def __call__(self):
path = config.config.path / sanitize_filename(self.path)
path.mkdir(parents=False, exist_ok=False)
class Rename(ControlBase):
path: str
to: str
def __call__(self):
to = sanitize_filename(self.to)
if "/" in to:
raise BadRequest("Rename 'to' name should only contain filename, not path")
path = config.config.path / sanitize_filename(self.path)
path.rename(path.with_name(to))
class Rm(ControlBase):
sel: list[str]
def __call__(self):
root = config.config.path
sel = [root / sanitize_filename(p) for p in self.sel]
for p in sel:
shutil.rmtree(p, ignore_errors=True)
class Mv(ControlBase):
sel: list[str]
dst: str
def __call__(self):
root = config.config.path
sel = [root / sanitize_filename(p) for p in self.sel]
dst = root / sanitize_filename(self.dst)
if not dst.is_dir():
raise BadRequest("The destination must be a directory")
for p in sel:
shutil.move(p, dst)
class Cp(ControlBase):
sel: list[str]
dst: str
def __call__(self):
root = config.config.path
sel = [root / sanitize_filename(p) for p in self.sel]
dst = root / sanitize_filename(self.dst)
if not dst.is_dir():
raise BadRequest("The destination must be a directory")
for p in sel:
# Note: copies as dst rather than in dst unless name is appended.
shutil.copytree(p, dst / p.name, dirs_exist_ok=True, ignore_dangling_symlinks=True)
## File uploads and downloads ## File uploads and downloads
class FileRange(msgspec.Struct): class FileRange(msgspec.Struct):
@ -52,7 +110,7 @@ class DirEntry(msgspec.Struct):
if k != "dir" if k != "dir"
} }
DirList = dict[str, Union[FileEntry, DirEntry]] DirList = dict[str, FileEntry | DirEntry]
class UpdateEntry(msgspec.Struct, omit_defaults=True): class UpdateEntry(msgspec.Struct, omit_defaults=True):

View File

@ -30,3 +30,10 @@ def update(res, s, **kwargs):
def delete(res): def delete(res):
res.cookies.delete_cookie("s") res.cookies.delete_cookie("s")
def flash(res, message: str | None):
if message is None:
res.cookies.delete_cookie("message")
else:
res.cookies.add_cookie("message", message, max_age=5)

View File

@ -30,6 +30,11 @@ Homepage = ""
[project.scripts] [project.scripts]
cista = "cista.__main__:main" cista = "cista.__main__:main"
[project.optional-dependencies]
dev = [
"pytest",
]
[tool.hatch.version] [tool.hatch.version]
source = "vcs" source = "vcs"
@ -43,3 +48,13 @@ version_tuple = {version_tuple!r}
targets.sdist.include = [ targets.sdist.include = [
"/cista", "/cista",
] ]
[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
"--verbosity=-1",
"-p no:warnings",
]
testpaths = [
"tests",
]

45
tests/test_control.py Normal file
View File

@ -0,0 +1,45 @@
import pytest
from cista import config
from cista.protocol import MkDir, Rename, Rm, Mv, Cp
from pathlib import Path
import tempfile
@pytest.fixture
def setup_temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
config.config = config.Config(path=Path(tmpdirname), listen=":0")
yield Path(tmpdirname)
def test_mkdir(setup_temp_dir):
cmd = MkDir(path="new_folder")
cmd()
assert (setup_temp_dir / "new_folder").is_dir()
def test_rename(setup_temp_dir):
(setup_temp_dir / "old_name").mkdir()
cmd = Rename(path="old_name", to="new_name")
cmd()
assert not (setup_temp_dir / "old_name").exists()
assert (setup_temp_dir / "new_name").is_dir()
def test_rm(setup_temp_dir):
(setup_temp_dir / "folder_to_remove").mkdir()
cmd = Rm(sel=["folder_to_remove"])
cmd()
assert not (setup_temp_dir / "folder_to_remove").exists()
def test_mv(setup_temp_dir):
(setup_temp_dir / "folder_to_move").mkdir()
(setup_temp_dir / "destination").mkdir()
cmd = Mv(sel=["folder_to_move"], dst="destination")
cmd()
assert not (setup_temp_dir / "folder_to_move").exists()
assert (setup_temp_dir / "destination" / "folder_to_move").is_dir()
def test_cp(setup_temp_dir):
(setup_temp_dir / "folder_to_copy").mkdir()
(setup_temp_dir / "destination").mkdir()
cmd = Cp(sel=["folder_to_copy"], dst="destination")
cmd()
assert (setup_temp_dir / "folder_to_copy").is_dir()
assert (setup_temp_dir / "destination" / "folder_to_copy").is_dir()