Formatting and fix Internal Server Error on upload

This commit is contained in:
Leo Vasanko 2023-10-28 20:20:34 +00:00
parent b3fd9637eb
commit acdd776b92
16 changed files with 176 additions and 52 deletions

View File

@ -1 +1,3 @@
from cista._version import __version__
__version__ # Public API

View File

@ -34,6 +34,7 @@ User management:
--password Reset password
"""
def main():
# Dev mode doesn't catch exceptions
if "--dev" in sys.argv:
@ -45,6 +46,7 @@ def main():
print("Error:", e)
return 1
def _main():
args = docopt(doc)
if args["--user"]:
@ -64,19 +66,25 @@ def _main():
if not necessary_opts:
# Maybe run without arguments
print(doc)
print("No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n")
print(
"No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n"
)
return 1
settings = {}
if import_droppy:
if exists:
raise ValueError(f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}")
raise ValueError(
f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}"
)
settings = droppy.readconf()
if path: settings["path"] = path
if listen: settings["listen"] = listen
if path:
settings["path"] = path
if listen:
settings["listen"] = listen
operation = config.update_config(settings)
print(f"Config {operation}: {config.conffile}")
# Prepare to serve
domain = unix = port = None
unix = None
url, _ = serve.parse_listen(config.config.listen)
if not config.config.path.is_dir():
raise ValueError(f"No such directory: {config.config.path}")
@ -88,6 +96,7 @@ def _main():
# Run the server
serve.run(dev=dev)
def _confdir(args):
if args["-c"]:
# Custom config directory
@ -99,6 +108,7 @@ def _confdir(args):
confdir = confdir.parent
config.conffile = config.conffile.with_parent(confdir)
def _user(args):
_confdir(args)
config.load_config()
@ -123,5 +133,6 @@ def _user(args):
if res == "read":
print(" No changes")
if __name__ == "__main__":
sys.exit(main())

View File

@ -12,15 +12,18 @@ from cista.util.apphelpers import asend, websocket_wrapper
bp = Blueprint("api", url_prefix="/api")
fileserver = FileServer()
@bp.before_server_start
async def start_fileserver(app, _):
await fileserver.start()
@bp.after_server_stop
async def stop_fileserver(app, _):
await fileserver.stop()
@bp.websocket('upload')
@bp.websocket("upload")
@websocket_wrapper
async def upload(req, ws):
alink = fileserver.alink
@ -28,7 +31,9 @@ async def upload(req, ws):
req = None
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
raise ValueError(
f"Expected JSON control, got binary len(data) = {len(text)}"
)
req = msgspec.json.decode(text, type=FileRange)
pos = req.start
data = None
@ -43,7 +48,8 @@ async def upload(req, ws):
await asend(ws, res)
# await ws.drain()
@bp.websocket('download')
@bp.websocket("download")
@websocket_wrapper
async def download(req, ws):
alink = fileserver.alink
@ -51,7 +57,9 @@ async def download(req, ws):
req = None
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
raise ValueError(
f"Expected JSON control, got binary len(data) = {len(text)}"
)
req = msgspec.json.decode(text, type=FileRange)
pos = req.start
while pos < req.end:
@ -64,6 +72,7 @@ async def download(req, ws):
await asend(ws, res)
# await ws.drain()
@bp.websocket("control")
@websocket_wrapper
async def control(req, ws):
@ -71,6 +80,7 @@ async def control(req, ws):
await asyncio.to_thread(cmd)
await asend(ws, StatusMsg(status="ack", req=cmd))
@bp.websocket("watch")
@websocket_wrapper
async def watch(req, ws):

View File

@ -2,7 +2,6 @@ import mimetypes
from importlib.resources import files
from urllib.parse import unquote
from html5tagger import E
from sanic import Blueprint, Sanic, raw
from sanic.exceptions import Forbidden, NotFound
@ -16,15 +15,18 @@ app.blueprint(auth.bp)
app.blueprint(bp)
app.exception(Exception)(handle_sanic_exception)
@app.before_server_start
async def main_start(app, loop):
config.load_config()
await watching.start(app, loop)
@app.after_server_stop
async def main_stop(app, loop):
await watching.stop(app, loop)
@app.on_request
async def use_session(req):
req.ctx.session = session.get(req)
@ -41,13 +43,21 @@ async def use_session(req):
if origin and origin.split("//", 1)[1] != req.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
@app.before_server_start
def http_fileserver(app, _):
bp = Blueprint("fileserver")
bp.on_request(auth.verify)
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
bp.static(
"/files/",
config.config.path,
use_content_range=True,
stream_large_files=True,
directory_view=True,
)
app.blueprint(bp)
@app.get("/<path:path>", static=True)
async def wwwroot(req, path=""):
"""Frontend files only"""
@ -55,6 +65,8 @@ async def wwwroot(req, path=""):
try:
index = files("cista").joinpath("wwwroot", name).read_bytes()
except OSError as e:
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
raise NotFound(
f"File not found: /{path}", extra={"name": name, "exception": repr(e)}
)
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
return raw(index, content_type=mime)

View File

@ -12,10 +12,12 @@ from sanic.exceptions import BadRequest, Forbidden, Unauthorized
from cista import config, session
_argon = argon2.PasswordHasher()
_droppyhash = re.compile(r'^([a-f0-9]{64})\$([a-f0-9]{8})$')
_droppyhash = re.compile(r"^([a-f0-9]{64})\$([a-f0-9]{8})$")
def _pwnorm(password):
return normalize('NFC', password).strip().encode()
return normalize("NFC", password).strip().encode()
def login(username: str, password: str):
un = _pwnorm(username)
@ -49,33 +51,51 @@ def login(username: str, password: str):
u.lastSeen = now
return u
def set_password(user: config.User, password: str):
user.hash = _argon.hash(_pwnorm(password))
class LoginResponse(msgspec.Struct):
user: str = ""
privileged: bool = False
error: str = ""
def verify(request, privileged=False):
"""Raise Unauthorized or Forbidden if the request is not authorized"""
if privileged:
if request.ctx.user:
if request.ctx.user.privileged: return
if request.ctx.user.privileged:
return
raise Forbidden("Access Forbidden: Only for privileged users")
elif config.config.public or request.ctx.user: return
elif config.config.public or request.ctx.user:
return
raise Unauthorized("Login required", "cookie", context={"redirect": "/login"})
bp = Blueprint("auth")
@bp.get("/login")
async def login_page(request):
doc = Document("Cista Login")
with doc.div(id="login"):
with doc.form(method="POST", autocomplete="on"):
doc.h1("Login")
doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br
doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br
doc.input(
name="username",
placeholder="Username",
autocomplete="username",
required=True,
).br
doc.input(
type="password",
name="password",
placeholder="Password",
autocomplete="current-password",
required=True,
).br
doc.input(type="submit", value="Login")
s = session.get(request)
if s:
@ -84,7 +104,12 @@ async def login_page(request):
doc.input(type="submit", value=f"Logout {name}")
flash = request.cookies.message
if flash:
doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8")
doc.dialog(
flash,
id="flash",
open=True,
style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8",
)
res = html(doc)
if flash:
res.cookies.delete_cookie("flash")
@ -92,6 +117,7 @@ async def login_page(request):
session.delete(res)
return res
@bp.post("/login")
async def login_post(request):
try:
@ -118,6 +144,7 @@ async def login_post(request):
session.create(res, username)
return res
@bp.post("/logout")
async def logout_post(request):
s = request.ctx.session

View File

@ -17,26 +17,28 @@ class Config(msgspec.Struct):
users: dict[str, User] = {}
links: dict[str, Link] = {}
class User(msgspec.Struct, omit_defaults=True):
privileged: bool = False
hash: str = ""
lastSeen: int = 0
class Link(msgspec.Struct, omit_defaults=True):
location: str
creator: str = ""
expires: int = 0
config = None
conffile = Path.home() / ".local/share/cista/db.toml"
def derived_secret(*params, len=8) -> bytes:
"""Used to derive secret keys from the main secret"""
# Each part is made the same length by hashing first
combined = b"".join(
sha256(
p if isinstance(p, bytes) else f"{p}".encode()
).digest()
sha256(p if isinstance(p, bytes) else f"{p}".encode()).digest()
for p in [config.secret, *params]
)
# Output a bytes of the desired length
@ -48,11 +50,13 @@ def enc_hook(obj):
return obj.as_posix()
raise TypeError
def dec_hook(typ, obj):
if typ is Path:
return Path(obj)
raise TypeError
def config_update(modify):
global config
if not conffile.exists():
@ -93,21 +97,28 @@ def config_update(modify):
config = c
return "modified" if old else "created"
def modifies_config(modify):
"""Decorator for functions that modify the config file"""
@wraps(modify)
def wrapper(*args, **kwargs):
m = lambda c: modify(c, *args, **kwargs)
def m(c):
return modify(c, *args, **kwargs)
# Retry modification in case of write collision
while (c := config_update(m)) == "collision":
time.sleep(0.01)
return c
return wrapper
def load_config():
global config
config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook)
@modifies_config
def update_config(conf: Config, changes: dict) -> Config:
"""Create/update the config with new values, respecting changes done by others."""
@ -116,6 +127,7 @@ def update_config(conf: Config, changes: dict) -> Config:
settings.update(changes)
return msgspec.convert(settings, Config, dec_hook=dec_hook)
@modifies_config
def update_user(conf: Config, name: str, changes: dict) -> Config:
"""Create/update a user with new values, respecting changes done by others."""
@ -126,6 +138,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
u = User()
if "password" in changes:
from . import auth
auth.set_password(u, changes["password"])
del changes["password"]
udict = msgspec.to_builtins(u, enc_hook=enc_hook)
@ -134,6 +147,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook)
return msgspec.convert(settings, Config, dec_hook=dec_hook)
@modifies_config
def del_user(conf: Config, name: str) -> Config:
"""Delete named user account."""

View File

@ -11,6 +11,7 @@ def readconf() -> dict:
cf["listen"] = _droppy_listeners(cf)
return cf | db
def _droppy_listeners(cf):
"""Convert Droppy listeners to our format, for typical cases but not in full."""
for listener in cf["listeners"]:
@ -20,15 +21,19 @@ def _droppy_listeners(cf):
continue
socket = listener.get("socket")
if socket:
if isinstance(socket, list): socket = socket[0]
if isinstance(socket, list):
socket = socket[0]
return f"{socket}"
port = listener["port"]
if isinstance(port, list): port = port[0]
if isinstance(port, list):
port = port[0]
host = listener["host"]
if isinstance(host, list): host = host[0]
if host in ("127.0.0.1", "::", "localhost"): return f":{port}"
if isinstance(host, list):
host = host[0]
if host in ("127.0.0.1", "::", "localhost"):
return f":{port}"
return f"{host}:{port}"
except (KeyError, IndexError):
continue
# If none matched, fallback to Droppy default
return f"0.0.0.0:8989"
return "0.0.0.0:8989"

View File

@ -1,7 +1,5 @@
import asyncio
import os
import unicodedata
from pathlib import PurePosixPath
from cista import config
from cista.util import filename
@ -13,6 +11,7 @@ def fuid(stat) -> str:
"""Unique file ID. Stays the same on renames and modification."""
return config.derived_secret("filekey-inode", stat.st_dev, stat.st_ino).hex()
class File:
def __init__(self, filename):
self.path = config.config.path / filename
@ -30,23 +29,24 @@ class File:
self.writable = True
def write(self, pos, buffer, *, file_size=None):
assert self.fd is not None
if not self.writable:
# Create/open file
self.open_rw()
assert self.fd is not None
if file_size is not None:
os.ftruncate(self.fd, file_size)
os.lseek(self.fd, pos, os.SEEK_SET)
os.write(self.fd, buffer)
def __getitem__(self, slice):
assert self.fd is not None
if self.fd is None:
self.open_ro()
assert self.fd is not None
os.lseek(self.fd, slice.start, os.SEEK_SET)
l = slice.stop - slice.start
data = os.read(self.fd, l)
if len(data) < l: raise EOFError("Error reading requested range")
size = slice.stop - slice.start
data = os.read(self.fd, size)
if len(data) < size:
raise EOFError("Error reading requested range")
return data
def close(self):
@ -59,10 +59,11 @@ class File:
class FileServer:
async def start(self):
self.alink = AsyncLink()
self.worker = asyncio.get_event_loop().run_in_executor(None, self.worker_thread, self.alink.to_sync)
self.worker = asyncio.get_event_loop().run_in_executor(
None, self.worker_thread, self.alink.to_sync
)
self.cache = LRUCache(File, capacity=10, maxage=5.0)
async def stop(self):

View File

@ -10,6 +10,7 @@ from cista import config, server80
def run(dev=False):
"""Run Sanic main process that spawns worker processes to serve HTTP requests."""
from .app import app
url, opts = parse_listen(config.config.listen)
# Silence Sanic's warning about running in production rather than debug
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1"
@ -21,20 +22,33 @@ def run(dev=False):
domain = opts["host"]
check_cert(confdir / domain, domain)
opts["ssl"] = str(confdir / domain) # type: ignore
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, reload_dir={confdir, wwwroot}, access_log=True) # type: ignore
app.prepare(
**opts,
motd=False,
dev=dev,
auto_reload=dev,
reload_dir={confdir, wwwroot},
access_log=True,
) # type: ignore
Sanic.serve()
def check_cert(certdir, domain):
if (certdir / "privkey.pem").exist() and (certdir / "fullchain.pem").exists():
return
# TODO: Use certbot to fetch a cert
raise ValueError(f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}")
raise ValueError(
f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}"
)
def parse_listen(listen):
if listen.startswith("/"):
unix = Path(listen).resolve()
if not unix.parent.exists():
raise ValueError(f"Directory for unix socket does not exist: {unix.parent}/")
raise ValueError(
f"Directory for unix socket does not exist: {unix.parent}/"
)
return "http://localhost", {"unix": unix}
elif re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE):
return f"https://{listen}", {"host": listen, "port": 443, "ssl": True}

View File

@ -2,6 +2,7 @@ from sanic import Sanic, exceptions, response
app = Sanic("server80")
# Send all HTTP users to HTTPS
@app.exception(exceptions.NotFound, exceptions.MethodNotSupported)
def redirect_everything_else(request, exception):
@ -10,6 +11,7 @@ def redirect_everything_else(request, exception):
return response.redirect(f"https://{server}{path}", status=308)
return response.text("Bad Request. Please use HTTPS!", status=400)
# ACME challenge for LetsEncrypt
@app.get("/.well-known/acme-challenge/<challenge>")
async def letsencrypt(request, challenge):
@ -18,4 +20,5 @@ async def letsencrypt(request, challenge):
except KeyError:
return response.text(f"ACME challenge not found: {challenge}", status=404)
acme_challenges = {}

View File

@ -4,16 +4,21 @@ import jwt
from cista.config import derived_secret
session_secret = lambda: derived_secret("session")
def session_secret():
return derived_secret("session")
max_age = 365 * 86400 # Seconds since last login
def get(request):
try:
return jwt.decode(request.cookies.s, session_secret(), algorithms=["HS256"])
except Exception as e:
s = None
except Exception:
return False if "s" in request.cookies else None
def create(res, username, **kwargs):
data = {
"exp": int(time()) + max_age,
@ -23,12 +28,14 @@ def create(res, username, **kwargs):
s = jwt.encode(data, session_secret())
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
def update(res, s, **kwargs):
s.update(kwargs)
s = jwt.encode(s, session_secret())
max_age = max(1, s["exp"] - int(time())) # type: ignore
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
def delete(res):
res.cookies.delete_cookie("s")

View File

@ -14,10 +14,12 @@ def asend(ws, msg):
"""Send JSON message or bytes to a websocket"""
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
def jres(data, **kwargs):
"""JSON Sanic response, using msgspec encoding"""
return raw(msgspec.json.encode(data), content_type="application/json", **kwargs)
async def handle_sanic_exception(request, e):
logger.exception(e)
context, code = {}, 500
@ -30,7 +32,9 @@ async def handle_sanic_exception(request, e):
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
# Non-browsers get JSON errors
if "text/html" not in request.headers.accept:
return jres(ErrorMsg({"code": code, "message": message, **context}), status=code)
return jres(
ErrorMsg({"code": code, "message": message, **context}), status=code
)
# Redirections flash the error message via cookies
if "redirect" in context:
res = redirect(context["redirect"])
@ -39,8 +43,10 @@ async def handle_sanic_exception(request, e):
# Otherwise use Sanic's default error page
return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full()
def websocket_wrapper(handler):
"""Decorator for websocket handlers that catches exceptions and sends them back to the client"""
@wraps(handler)
async def wrapper(request, ws, *args, **kwargs):
try:
@ -55,4 +61,5 @@ def websocket_wrapper(handler):
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
await asend(ws, ErrorMsg({"code": code, "message": message, **context}))
raise
return wrapper

View File

@ -11,6 +11,7 @@ class LRUCache:
maxage (float): Max age for items in cache in seconds.
cache (list): Internal list storing the cache items.
"""
def __init__(self, open: callable, *, capacity: int, maxage: float):
"""
Initialize LRUCache.

View File

@ -6,6 +6,7 @@ def generate(n=4):
wl = list(words)
return ".".join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
# A custom list of 1024 common 3-6 letter words, with unique 3-prefixes and no prefix words, entropy 2.1b/letter 10b/word
words: list = """
able about absent abuse access acid across act adapt add adjust admit adult advice affair afraid again age agree ahead

View File

@ -13,11 +13,13 @@ def setup_temp_dir():
config.config = config.Config(path=Path(tmpdirname), listen=":0")
yield Path(tmpdirname)
def test_mkdir(setup_temp_dir):
cmd = MkDir(path="new_folder")
cmd()
assert (setup_temp_dir / "new_folder").is_dir()
def test_rename(setup_temp_dir):
(setup_temp_dir / "old_name").mkdir()
cmd = Rename(path="old_name", to="new_name")
@ -25,12 +27,14 @@ def test_rename(setup_temp_dir):
assert not (setup_temp_dir / "old_name").exists()
assert (setup_temp_dir / "new_name").is_dir()
def test_rm(setup_temp_dir):
(setup_temp_dir / "folder_to_remove").mkdir()
cmd = Rm(sel=["folder_to_remove"])
cmd()
assert not (setup_temp_dir / "folder_to_remove").exists()
def test_mv(setup_temp_dir):
(setup_temp_dir / "folder_to_move").mkdir()
(setup_temp_dir / "destination").mkdir()
@ -39,6 +43,7 @@ def test_mv(setup_temp_dir):
assert not (setup_temp_dir / "folder_to_move").exists()
assert (setup_temp_dir / "destination" / "folder_to_move").is_dir()
def test_cp(setup_temp_dir):
(setup_temp_dir / "folder_to_copy").mkdir()
(setup_temp_dir / "destination").mkdir()

View File

@ -1,8 +1,6 @@
from time import sleep
from unittest.mock import Mock
import pytest
from cista.util.lrucache import LRUCache # Replace with actual import
@ -12,24 +10,28 @@ def mock_open(key):
mock.content = f"content-{key}"
return mock
def test_contains():
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
assert "key1" not in cache
cache["key1"]
assert "key1" in cache
def test_getitem():
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
assert cache["key1"].content == "content-key1"
def test_capacity():
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
item1 = cache["key1"]
item2 = cache["key2"]
cache["key2"]
cache["key3"]
assert "key1" not in cache
item1.close.assert_called()
def test_expiry():
cache = LRUCache(open=mock_open, capacity=2, maxage=0.1)
item = cache["key1"]
@ -38,6 +40,7 @@ def test_expiry():
assert "key1" not in cache
item.close.assert_called()
def test_close():
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
item = cache["key1"]
@ -45,6 +48,7 @@ def test_close():
assert "key1" not in cache
item.close.assert_called()
def test_lru_mechanism():
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
item1 = cache["key1"]