Implemented control commands and tests. Rewritten error and session/flash handling.
This commit is contained in:
parent
9939cb33fa
commit
e90174a09d
51
cista/app.py
51
cista/app.py
|
@ -1,21 +1,54 @@
|
||||||
|
import asyncio
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
from html5tagger import E
|
from html5tagger import E
|
||||||
from sanic import Sanic, redirect
|
from sanic import Forbidden, Sanic, SanicException, errorpages
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
from sanic.response import html
|
from sanic.response import html, json, redirect
|
||||||
|
|
||||||
from . import config, session, watching
|
from . import config, session, watching
|
||||||
from .auth import authbp
|
from .auth import authbp
|
||||||
from .fileio import FileServer
|
from .fileio import FileServer
|
||||||
from .protocol import ErrorMsg, FileRange, StatusMsg
|
from .protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
|
||||||
|
|
||||||
|
|
||||||
app = Sanic("cista")
|
app = Sanic("cista")
|
||||||
fileserver = FileServer()
|
fileserver = FileServer()
|
||||||
watching.register(app, "/api/watch")
|
watching.register(app, "/api/watch")
|
||||||
app.blueprint(authbp)
|
app.blueprint(authbp)
|
||||||
|
|
||||||
|
@app.on_request
|
||||||
|
async def use_session(request):
|
||||||
|
request.ctx.session = session.get(request)
|
||||||
|
# CSRF protection
|
||||||
|
if request.method == "GET" and request.headers.upgrade != "websocket":
|
||||||
|
return # Ordinary GET requests are fine
|
||||||
|
if request.headers.origin.split("//", 1)[1] != request.host:
|
||||||
|
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
|
||||||
|
|
||||||
|
@app.exception(Exception)
|
||||||
|
async def handle_sanic_exception(request, e):
|
||||||
|
context, code, message = {}, 500, str(e) or "Internal Server Error"
|
||||||
|
if isinstance(e, SanicException):
|
||||||
|
context = e.context or {}
|
||||||
|
code = e.status_code
|
||||||
|
message = f"⚠️ {message}"
|
||||||
|
# Non-browsers get JSON errors
|
||||||
|
if "text/html" not in request.headers.accept:
|
||||||
|
return json({"error": {"code": code, "message": message, **context}}, status=code)
|
||||||
|
# Redirections flash the error message via cookies
|
||||||
|
if "redirect" in context:
|
||||||
|
res = redirect(context["redirect"])
|
||||||
|
res.cookies.add_cookie("message", message, max_age=5)
|
||||||
|
return res
|
||||||
|
# Otherwise use Sanic's default error page
|
||||||
|
return errorpages.HTMLRenderer(request, e, debug=app.debug).full()
|
||||||
|
|
||||||
|
@app.on_response
|
||||||
|
async def resphandler(request, response):
|
||||||
|
print("In response handler", request.path, response)
|
||||||
|
|
||||||
def asend(ws, msg):
|
def asend(ws, msg):
|
||||||
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
|
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
|
||||||
|
|
||||||
|
@ -36,11 +69,11 @@ async def index_page(request):
|
||||||
if not s:
|
if not s:
|
||||||
return redirect("/login")
|
return redirect("/login")
|
||||||
index = files("cista").joinpath("static", "index.html").read_text()
|
index = files("cista").joinpath("static", "index.html").read_text()
|
||||||
flash = request.cookies.flash
|
flash = request.cookies.message
|
||||||
if flash:
|
if flash:
|
||||||
index += str(E.div(flash, id="flash"))
|
index += str(E.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8"))
|
||||||
res = html(index)
|
res = html(index)
|
||||||
res.cookies.delete_cookie("flash")
|
session.flash(res, None)
|
||||||
return res
|
return res
|
||||||
return html(index)
|
return html(index)
|
||||||
|
|
||||||
|
@ -101,3 +134,9 @@ async def download(request, ws):
|
||||||
await asend(ws, res)
|
await asend(ws, res)
|
||||||
logger.exception(repr(res), e)
|
logger.exception(repr(res), e)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@app.websocket("/api/control")
|
||||||
|
async def control(request, websocket):
|
||||||
|
cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
|
||||||
|
await asyncio.to_thread(cmd)
|
||||||
|
await asend(websocket, StatusMsg(status="ack", req=cmd))
|
||||||
|
|
|
@ -6,7 +6,7 @@ from unicodedata import normalize
|
||||||
import argon2
|
import argon2
|
||||||
import msgspec
|
import msgspec
|
||||||
from html5tagger import Document
|
from html5tagger import Document
|
||||||
from sanic import BadRequest, Blueprint, html, json, redirect
|
from sanic import BadRequest, Blueprint, Forbidden, html, json, redirect
|
||||||
|
|
||||||
from . import config, session
|
from . import config, session
|
||||||
|
|
||||||
|
@ -64,18 +64,17 @@ async def login_page(request):
|
||||||
with doc.div(id="login"):
|
with doc.div(id="login"):
|
||||||
with doc.form(method="POST", autocomplete="on"):
|
with doc.form(method="POST", autocomplete="on"):
|
||||||
doc.h1("Login")
|
doc.h1("Login")
|
||||||
doc.input(name="username", placeholder="Username", autocomplete="username").br
|
doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br
|
||||||
doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password").br
|
doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br
|
||||||
doc.input(type="submit", value="Login")
|
doc.input(type="submit", value="Login")
|
||||||
s = session.get(request)
|
s = session.get(request)
|
||||||
if s:
|
if s:
|
||||||
name = s["username"]
|
name = s["username"]
|
||||||
with doc.form(method="POST", action="/logout"):
|
with doc.form(method="POST", action="/logout"):
|
||||||
doc.input(type="submit", value=f"Logout {name}")
|
doc.input(type="submit", value=f"Logout {name}")
|
||||||
flash = request.cookies.flash
|
flash = request.cookies.message
|
||||||
if flash:
|
if flash:
|
||||||
print("flash", flash)
|
doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8")
|
||||||
doc.p(flash)
|
|
||||||
res = html(doc)
|
res = html(doc)
|
||||||
if flash:
|
if flash:
|
||||||
res.cookies.delete_cookie("flash")
|
res.cookies.delete_cookie("flash")
|
||||||
|
@ -85,9 +84,8 @@ async def login_page(request):
|
||||||
|
|
||||||
@authbp.post("/login")
|
@authbp.post("/login")
|
||||||
async def login_post(request):
|
async def login_post(request):
|
||||||
json_format = request.headers.content_type == "application/json"
|
|
||||||
try:
|
try:
|
||||||
if json_format:
|
if request.headers.content_type == "application/json":
|
||||||
username = request.json["username"]
|
username = request.json["username"]
|
||||||
password = request.json["password"]
|
password = request.json["password"]
|
||||||
else:
|
else:
|
||||||
|
@ -96,36 +94,28 @@ async def login_post(request):
|
||||||
if not username or not password:
|
if not username or not password:
|
||||||
raise KeyError
|
raise KeyError
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise BadRequest("Missing username or password")
|
raise BadRequest("Missing username or password", context={"redirect": "/login"})
|
||||||
try:
|
try:
|
||||||
user = login(username, password)
|
user = login(username, password)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if json_format:
|
raise Forbidden(str(e), context={"redirect": "/login"})
|
||||||
res = json({
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
res = redirect("/login")
|
|
||||||
res.cookies.add_cookie("flash", str(e), max_age=5)
|
|
||||||
print("Login error:", res.cookies)
|
|
||||||
return res
|
|
||||||
|
|
||||||
if json_format:
|
if "text/html" in request.headers.accept:
|
||||||
res = json({
|
|
||||||
"status": "authenticated",
|
|
||||||
"user": username,
|
|
||||||
"privileged": user.privileged,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
res = redirect("/")
|
res = redirect("/")
|
||||||
res.cookies.add_cookie("flash", "Logged in", max_age=5)
|
session.flash(res, "Logged in")
|
||||||
|
else:
|
||||||
|
res = json({"data": {"username": username, "privileged": user.privileged}})
|
||||||
session.create(res, username)
|
session.create(res, username)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@authbp.post("/logout")
|
@authbp.post("/logout")
|
||||||
async def logout_post(request):
|
async def logout_post(request):
|
||||||
res = redirect("/")
|
s = request.ctx.session
|
||||||
|
msg = "Logged out" if s else "Not logged in"
|
||||||
|
if "text/html" in request.headers.accept:
|
||||||
|
res = redirect("/login")
|
||||||
|
res.cookies.add_cookie("flash", msg, max_age=5)
|
||||||
|
else:
|
||||||
|
res = json({"message": msg})
|
||||||
session.delete(res)
|
session.delete(res)
|
||||||
res.cookies.add_cookie("flash", "Logged out", max_age=5)
|
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -5,7 +5,7 @@ from pathlib import Path, PurePosixPath
|
||||||
|
|
||||||
from pathvalidate import sanitize_filepath
|
from pathvalidate import sanitize_filepath
|
||||||
|
|
||||||
from . import config
|
from . import config, protocol
|
||||||
from .asynclink import AsyncLink
|
from .asynclink import AsyncLink
|
||||||
from .lrucache import LRUCache
|
from .lrucache import LRUCache
|
||||||
|
|
||||||
|
@ -80,6 +80,8 @@ class FileServer:
|
||||||
try:
|
try:
|
||||||
for req in slink:
|
for req in slink:
|
||||||
with req as (command, *args):
|
with req as (command, *args):
|
||||||
|
if command == "control":
|
||||||
|
self.control(*args)
|
||||||
if command == "upload":
|
if command == "upload":
|
||||||
req.set_result(self.upload(*args))
|
req.set_result(self.upload(*args))
|
||||||
elif command == "download":
|
elif command == "download":
|
||||||
|
|
|
@ -1,9 +1,67 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import shutil
|
||||||
|
|
||||||
from typing import Dict, Tuple, Union
|
from sanic import BadRequest
|
||||||
|
from cista import config
|
||||||
|
from cista.fileio import sanitize_filename
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
|
## Control commands
|
||||||
|
|
||||||
|
class ControlBase(msgspec.Struct, tag_field="op", tag=str.lower):
|
||||||
|
def __call__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class MkDir(ControlBase):
|
||||||
|
path: str
|
||||||
|
def __call__(self):
|
||||||
|
path = config.config.path / sanitize_filename(self.path)
|
||||||
|
path.mkdir(parents=False, exist_ok=False)
|
||||||
|
|
||||||
|
class Rename(ControlBase):
|
||||||
|
path: str
|
||||||
|
to: str
|
||||||
|
def __call__(self):
|
||||||
|
to = sanitize_filename(self.to)
|
||||||
|
if "/" in to:
|
||||||
|
raise BadRequest("Rename 'to' name should only contain filename, not path")
|
||||||
|
path = config.config.path / sanitize_filename(self.path)
|
||||||
|
path.rename(path.with_name(to))
|
||||||
|
|
||||||
|
class Rm(ControlBase):
|
||||||
|
sel: list[str]
|
||||||
|
def __call__(self):
|
||||||
|
root = config.config.path
|
||||||
|
sel = [root / sanitize_filename(p) for p in self.sel]
|
||||||
|
for p in sel:
|
||||||
|
shutil.rmtree(p, ignore_errors=True)
|
||||||
|
|
||||||
|
class Mv(ControlBase):
|
||||||
|
sel: list[str]
|
||||||
|
dst: str
|
||||||
|
def __call__(self):
|
||||||
|
root = config.config.path
|
||||||
|
sel = [root / sanitize_filename(p) for p in self.sel]
|
||||||
|
dst = root / sanitize_filename(self.dst)
|
||||||
|
if not dst.is_dir():
|
||||||
|
raise BadRequest("The destination must be a directory")
|
||||||
|
for p in sel:
|
||||||
|
shutil.move(p, dst)
|
||||||
|
|
||||||
|
class Cp(ControlBase):
|
||||||
|
sel: list[str]
|
||||||
|
dst: str
|
||||||
|
def __call__(self):
|
||||||
|
root = config.config.path
|
||||||
|
sel = [root / sanitize_filename(p) for p in self.sel]
|
||||||
|
dst = root / sanitize_filename(self.dst)
|
||||||
|
if not dst.is_dir():
|
||||||
|
raise BadRequest("The destination must be a directory")
|
||||||
|
for p in sel:
|
||||||
|
# Note: copies as dst rather than in dst unless name is appended.
|
||||||
|
shutil.copytree(p, dst / p.name, dirs_exist_ok=True, ignore_dangling_symlinks=True)
|
||||||
|
|
||||||
## File uploads and downloads
|
## File uploads and downloads
|
||||||
|
|
||||||
class FileRange(msgspec.Struct):
|
class FileRange(msgspec.Struct):
|
||||||
|
@ -52,7 +110,7 @@ class DirEntry(msgspec.Struct):
|
||||||
if k != "dir"
|
if k != "dir"
|
||||||
}
|
}
|
||||||
|
|
||||||
DirList = dict[str, Union[FileEntry, DirEntry]]
|
DirList = dict[str, FileEntry | DirEntry]
|
||||||
|
|
||||||
|
|
||||||
class UpdateEntry(msgspec.Struct, omit_defaults=True):
|
class UpdateEntry(msgspec.Struct, omit_defaults=True):
|
||||||
|
|
|
@ -30,3 +30,10 @@ def update(res, s, **kwargs):
|
||||||
|
|
||||||
def delete(res):
|
def delete(res):
|
||||||
res.cookies.delete_cookie("s")
|
res.cookies.delete_cookie("s")
|
||||||
|
|
||||||
|
|
||||||
|
def flash(res, message: str | None):
|
||||||
|
if message is None:
|
||||||
|
res.cookies.delete_cookie("message")
|
||||||
|
else:
|
||||||
|
res.cookies.add_cookie("message", message, max_age=5)
|
||||||
|
|
|
@ -30,6 +30,11 @@ Homepage = ""
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
cista = "cista.__main__:main"
|
cista = "cista.__main__:main"
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.hatch.version]
|
[tool.hatch.version]
|
||||||
source = "vcs"
|
source = "vcs"
|
||||||
|
|
||||||
|
@ -43,3 +48,13 @@ version_tuple = {version_tuple!r}
|
||||||
targets.sdist.include = [
|
targets.sdist.include = [
|
||||||
"/cista",
|
"/cista",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = [
|
||||||
|
"--import-mode=importlib",
|
||||||
|
"--verbosity=-1",
|
||||||
|
"-p no:warnings",
|
||||||
|
]
|
||||||
|
testpaths = [
|
||||||
|
"tests",
|
||||||
|
]
|
||||||
|
|
45
tests/test_control.py
Normal file
45
tests/test_control.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
import pytest
|
||||||
|
from cista import config
|
||||||
|
from cista.protocol import MkDir, Rename, Rm, Mv, Cp
|
||||||
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def setup_temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
config.config = config.Config(path=Path(tmpdirname), listen=":0")
|
||||||
|
yield Path(tmpdirname)
|
||||||
|
|
||||||
|
def test_mkdir(setup_temp_dir):
|
||||||
|
cmd = MkDir(path="new_folder")
|
||||||
|
cmd()
|
||||||
|
assert (setup_temp_dir / "new_folder").is_dir()
|
||||||
|
|
||||||
|
def test_rename(setup_temp_dir):
|
||||||
|
(setup_temp_dir / "old_name").mkdir()
|
||||||
|
cmd = Rename(path="old_name", to="new_name")
|
||||||
|
cmd()
|
||||||
|
assert not (setup_temp_dir / "old_name").exists()
|
||||||
|
assert (setup_temp_dir / "new_name").is_dir()
|
||||||
|
|
||||||
|
def test_rm(setup_temp_dir):
|
||||||
|
(setup_temp_dir / "folder_to_remove").mkdir()
|
||||||
|
cmd = Rm(sel=["folder_to_remove"])
|
||||||
|
cmd()
|
||||||
|
assert not (setup_temp_dir / "folder_to_remove").exists()
|
||||||
|
|
||||||
|
def test_mv(setup_temp_dir):
|
||||||
|
(setup_temp_dir / "folder_to_move").mkdir()
|
||||||
|
(setup_temp_dir / "destination").mkdir()
|
||||||
|
cmd = Mv(sel=["folder_to_move"], dst="destination")
|
||||||
|
cmd()
|
||||||
|
assert not (setup_temp_dir / "folder_to_move").exists()
|
||||||
|
assert (setup_temp_dir / "destination" / "folder_to_move").is_dir()
|
||||||
|
|
||||||
|
def test_cp(setup_temp_dir):
|
||||||
|
(setup_temp_dir / "folder_to_copy").mkdir()
|
||||||
|
(setup_temp_dir / "destination").mkdir()
|
||||||
|
cmd = Cp(sel=["folder_to_copy"], dst="destination")
|
||||||
|
cmd()
|
||||||
|
assert (setup_temp_dir / "folder_to_copy").is_dir()
|
||||||
|
assert (setup_temp_dir / "destination" / "folder_to_copy").is_dir()
|
Loading…
Reference in New Issue
Block a user