Formatting and fix Internal Server Error on upload
This commit is contained in:
parent
b3fd9637eb
commit
acdd776b92
|
@ -1 +1,3 @@
|
|||
from cista._version import __version__
|
||||
|
||||
__version__ # Public API
|
||||
|
|
|
@ -34,6 +34,7 @@ User management:
|
|||
--password Reset password
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
# Dev mode doesn't catch exceptions
|
||||
if "--dev" in sys.argv:
|
||||
|
@ -45,6 +46,7 @@ def main():
|
|||
print("Error:", e)
|
||||
return 1
|
||||
|
||||
|
||||
def _main():
|
||||
args = docopt(doc)
|
||||
if args["--user"]:
|
||||
|
@ -64,19 +66,25 @@ def _main():
|
|||
if not necessary_opts:
|
||||
# Maybe run without arguments
|
||||
print(doc)
|
||||
print("No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n")
|
||||
print(
|
||||
"No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n"
|
||||
)
|
||||
return 1
|
||||
settings = {}
|
||||
if import_droppy:
|
||||
if exists:
|
||||
raise ValueError(f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}")
|
||||
raise ValueError(
|
||||
f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}"
|
||||
)
|
||||
settings = droppy.readconf()
|
||||
if path: settings["path"] = path
|
||||
if listen: settings["listen"] = listen
|
||||
if path:
|
||||
settings["path"] = path
|
||||
if listen:
|
||||
settings["listen"] = listen
|
||||
operation = config.update_config(settings)
|
||||
print(f"Config {operation}: {config.conffile}")
|
||||
# Prepare to serve
|
||||
domain = unix = port = None
|
||||
unix = None
|
||||
url, _ = serve.parse_listen(config.config.listen)
|
||||
if not config.config.path.is_dir():
|
||||
raise ValueError(f"No such directory: {config.config.path}")
|
||||
|
@ -88,6 +96,7 @@ def _main():
|
|||
# Run the server
|
||||
serve.run(dev=dev)
|
||||
|
||||
|
||||
def _confdir(args):
|
||||
if args["-c"]:
|
||||
# Custom config directory
|
||||
|
@ -99,6 +108,7 @@ def _confdir(args):
|
|||
confdir = confdir.parent
|
||||
config.conffile = config.conffile.with_parent(confdir)
|
||||
|
||||
|
||||
def _user(args):
|
||||
_confdir(args)
|
||||
config.load_config()
|
||||
|
@ -123,5 +133,6 @@ def _user(args):
|
|||
if res == "read":
|
||||
print(" No changes")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
|
24
cista/api.py
24
cista/api.py
|
@ -12,15 +12,18 @@ from cista.util.apphelpers import asend, websocket_wrapper
|
|||
bp = Blueprint("api", url_prefix="/api")
|
||||
fileserver = FileServer()
|
||||
|
||||
|
||||
@bp.before_server_start
|
||||
async def start_fileserver(app, _):
|
||||
await fileserver.start()
|
||||
|
||||
|
||||
@bp.after_server_stop
|
||||
async def stop_fileserver(app, _):
|
||||
await fileserver.stop()
|
||||
|
||||
@bp.websocket('upload')
|
||||
|
||||
@bp.websocket("upload")
|
||||
@websocket_wrapper
|
||||
async def upload(req, ws):
|
||||
alink = fileserver.alink
|
||||
|
@ -28,7 +31,9 @@ async def upload(req, ws):
|
|||
req = None
|
||||
text = await ws.recv()
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
||||
raise ValueError(
|
||||
f"Expected JSON control, got binary len(data) = {len(text)}"
|
||||
)
|
||||
req = msgspec.json.decode(text, type=FileRange)
|
||||
pos = req.start
|
||||
data = None
|
||||
|
@ -41,9 +46,10 @@ async def upload(req, ws):
|
|||
# Report success
|
||||
res = StatusMsg(status="ack", req=req)
|
||||
await asend(ws, res)
|
||||
#await ws.drain()
|
||||
# await ws.drain()
|
||||
|
||||
@bp.websocket('download')
|
||||
|
||||
@bp.websocket("download")
|
||||
@websocket_wrapper
|
||||
async def download(req, ws):
|
||||
alink = fileserver.alink
|
||||
|
@ -51,18 +57,21 @@ async def download(req, ws):
|
|||
req = None
|
||||
text = await ws.recv()
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
||||
raise ValueError(
|
||||
f"Expected JSON control, got binary len(data) = {len(text)}"
|
||||
)
|
||||
req = msgspec.json.decode(text, type=FileRange)
|
||||
pos = req.start
|
||||
while pos < req.end:
|
||||
end = min(req.end, pos + (1<<20))
|
||||
end = min(req.end, pos + (1 << 20))
|
||||
data = typing.cast(bytes, await alink(("download", req.name, pos, end)))
|
||||
await asend(ws, data)
|
||||
pos += len(data)
|
||||
# Report success
|
||||
res = StatusMsg(status="ack", req=req)
|
||||
await asend(ws, res)
|
||||
#await ws.drain()
|
||||
# await ws.drain()
|
||||
|
||||
|
||||
@bp.websocket("control")
|
||||
@websocket_wrapper
|
||||
|
@ -71,6 +80,7 @@ async def control(req, ws):
|
|||
await asyncio.to_thread(cmd)
|
||||
await asend(ws, StatusMsg(status="ack", req=cmd))
|
||||
|
||||
|
||||
@bp.websocket("watch")
|
||||
@websocket_wrapper
|
||||
async def watch(req, ws):
|
||||
|
|
18
cista/app.py
18
cista/app.py
|
@ -2,7 +2,6 @@ import mimetypes
|
|||
from importlib.resources import files
|
||||
from urllib.parse import unquote
|
||||
|
||||
from html5tagger import E
|
||||
from sanic import Blueprint, Sanic, raw
|
||||
from sanic.exceptions import Forbidden, NotFound
|
||||
|
||||
|
@ -16,15 +15,18 @@ app.blueprint(auth.bp)
|
|||
app.blueprint(bp)
|
||||
app.exception(Exception)(handle_sanic_exception)
|
||||
|
||||
|
||||
@app.before_server_start
|
||||
async def main_start(app, loop):
|
||||
config.load_config()
|
||||
await watching.start(app, loop)
|
||||
|
||||
|
||||
@app.after_server_stop
|
||||
async def main_stop(app, loop):
|
||||
await watching.stop(app, loop)
|
||||
|
||||
|
||||
@app.on_request
|
||||
async def use_session(req):
|
||||
req.ctx.session = session.get(req)
|
||||
|
@ -41,13 +43,21 @@ async def use_session(req):
|
|||
if origin and origin.split("//", 1)[1] != req.host:
|
||||
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
|
||||
|
||||
|
||||
@app.before_server_start
|
||||
def http_fileserver(app, _):
|
||||
bp = Blueprint("fileserver")
|
||||
bp.on_request(auth.verify)
|
||||
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
|
||||
bp.static(
|
||||
"/files/",
|
||||
config.config.path,
|
||||
use_content_range=True,
|
||||
stream_large_files=True,
|
||||
directory_view=True,
|
||||
)
|
||||
app.blueprint(bp)
|
||||
|
||||
|
||||
@app.get("/<path:path>", static=True)
|
||||
async def wwwroot(req, path=""):
|
||||
"""Frontend files only"""
|
||||
|
@ -55,6 +65,8 @@ async def wwwroot(req, path=""):
|
|||
try:
|
||||
index = files("cista").joinpath("wwwroot", name).read_bytes()
|
||||
except OSError as e:
|
||||
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
|
||||
raise NotFound(
|
||||
f"File not found: /{path}", extra={"name": name, "exception": repr(e)}
|
||||
)
|
||||
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
|
||||
return raw(index, content_type=mime)
|
||||
|
|
|
@ -12,10 +12,12 @@ from sanic.exceptions import BadRequest, Forbidden, Unauthorized
|
|||
from cista import config, session
|
||||
|
||||
_argon = argon2.PasswordHasher()
|
||||
_droppyhash = re.compile(r'^([a-f0-9]{64})\$([a-f0-9]{8})$')
|
||||
_droppyhash = re.compile(r"^([a-f0-9]{64})\$([a-f0-9]{8})$")
|
||||
|
||||
|
||||
def _pwnorm(password):
|
||||
return normalize('NFC', password).strip().encode()
|
||||
return normalize("NFC", password).strip().encode()
|
||||
|
||||
|
||||
def login(username: str, password: str):
|
||||
un = _pwnorm(username)
|
||||
|
@ -49,33 +51,51 @@ def login(username: str, password: str):
|
|||
u.lastSeen = now
|
||||
return u
|
||||
|
||||
|
||||
def set_password(user: config.User, password: str):
|
||||
user.hash = _argon.hash(_pwnorm(password))
|
||||
|
||||
|
||||
class LoginResponse(msgspec.Struct):
|
||||
user: str = ""
|
||||
privileged: bool = False
|
||||
error: str = ""
|
||||
|
||||
|
||||
def verify(request, privileged=False):
|
||||
"""Raise Unauthorized or Forbidden if the request is not authorized"""
|
||||
if privileged:
|
||||
if request.ctx.user:
|
||||
if request.ctx.user.privileged: return
|
||||
if request.ctx.user.privileged:
|
||||
return
|
||||
raise Forbidden("Access Forbidden: Only for privileged users")
|
||||
elif config.config.public or request.ctx.user: return
|
||||
elif config.config.public or request.ctx.user:
|
||||
return
|
||||
raise Unauthorized("Login required", "cookie", context={"redirect": "/login"})
|
||||
|
||||
|
||||
bp = Blueprint("auth")
|
||||
|
||||
|
||||
@bp.get("/login")
|
||||
async def login_page(request):
|
||||
doc = Document("Cista Login")
|
||||
with doc.div(id="login"):
|
||||
with doc.form(method="POST", autocomplete="on"):
|
||||
doc.h1("Login")
|
||||
doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br
|
||||
doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br
|
||||
doc.input(
|
||||
name="username",
|
||||
placeholder="Username",
|
||||
autocomplete="username",
|
||||
required=True,
|
||||
).br
|
||||
doc.input(
|
||||
type="password",
|
||||
name="password",
|
||||
placeholder="Password",
|
||||
autocomplete="current-password",
|
||||
required=True,
|
||||
).br
|
||||
doc.input(type="submit", value="Login")
|
||||
s = session.get(request)
|
||||
if s:
|
||||
|
@ -84,7 +104,12 @@ async def login_page(request):
|
|||
doc.input(type="submit", value=f"Logout {name}")
|
||||
flash = request.cookies.message
|
||||
if flash:
|
||||
doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8")
|
||||
doc.dialog(
|
||||
flash,
|
||||
id="flash",
|
||||
open=True,
|
||||
style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8",
|
||||
)
|
||||
res = html(doc)
|
||||
if flash:
|
||||
res.cookies.delete_cookie("flash")
|
||||
|
@ -92,6 +117,7 @@ async def login_page(request):
|
|||
session.delete(res)
|
||||
return res
|
||||
|
||||
|
||||
@bp.post("/login")
|
||||
async def login_post(request):
|
||||
try:
|
||||
|
@ -118,6 +144,7 @@ async def login_post(request):
|
|||
session.create(res, username)
|
||||
return res
|
||||
|
||||
|
||||
@bp.post("/logout")
|
||||
async def logout_post(request):
|
||||
s = request.ctx.session
|
||||
|
|
|
@ -17,26 +17,28 @@ class Config(msgspec.Struct):
|
|||
users: dict[str, User] = {}
|
||||
links: dict[str, Link] = {}
|
||||
|
||||
|
||||
class User(msgspec.Struct, omit_defaults=True):
|
||||
privileged: bool = False
|
||||
hash: str = ""
|
||||
lastSeen: int = 0
|
||||
|
||||
|
||||
class Link(msgspec.Struct, omit_defaults=True):
|
||||
location: str
|
||||
creator: str = ""
|
||||
expires: int = 0
|
||||
|
||||
|
||||
config = None
|
||||
conffile = Path.home() / ".local/share/cista/db.toml"
|
||||
|
||||
|
||||
def derived_secret(*params, len=8) -> bytes:
|
||||
"""Used to derive secret keys from the main secret"""
|
||||
# Each part is made the same length by hashing first
|
||||
combined = b"".join(
|
||||
sha256(
|
||||
p if isinstance(p, bytes) else f"{p}".encode()
|
||||
).digest()
|
||||
sha256(p if isinstance(p, bytes) else f"{p}".encode()).digest()
|
||||
for p in [config.secret, *params]
|
||||
)
|
||||
# Output a bytes of the desired length
|
||||
|
@ -48,11 +50,13 @@ def enc_hook(obj):
|
|||
return obj.as_posix()
|
||||
raise TypeError
|
||||
|
||||
|
||||
def dec_hook(typ, obj):
|
||||
if typ is Path:
|
||||
return Path(obj)
|
||||
raise TypeError
|
||||
|
||||
|
||||
def config_update(modify):
|
||||
global config
|
||||
if not conffile.exists():
|
||||
|
@ -93,21 +97,28 @@ def config_update(modify):
|
|||
config = c
|
||||
return "modified" if old else "created"
|
||||
|
||||
|
||||
def modifies_config(modify):
|
||||
"""Decorator for functions that modify the config file"""
|
||||
|
||||
@wraps(modify)
|
||||
def wrapper(*args, **kwargs):
|
||||
m = lambda c: modify(c, *args, **kwargs)
|
||||
def m(c):
|
||||
return modify(c, *args, **kwargs)
|
||||
|
||||
# Retry modification in case of write collision
|
||||
while (c := config_update(m)) == "collision":
|
||||
time.sleep(0.01)
|
||||
return c
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook)
|
||||
|
||||
|
||||
@modifies_config
|
||||
def update_config(conf: Config, changes: dict) -> Config:
|
||||
"""Create/update the config with new values, respecting changes done by others."""
|
||||
|
@ -116,6 +127,7 @@ def update_config(conf: Config, changes: dict) -> Config:
|
|||
settings.update(changes)
|
||||
return msgspec.convert(settings, Config, dec_hook=dec_hook)
|
||||
|
||||
|
||||
@modifies_config
|
||||
def update_user(conf: Config, name: str, changes: dict) -> Config:
|
||||
"""Create/update a user with new values, respecting changes done by others."""
|
||||
|
@ -126,6 +138,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
|
|||
u = User()
|
||||
if "password" in changes:
|
||||
from . import auth
|
||||
|
||||
auth.set_password(u, changes["password"])
|
||||
del changes["password"]
|
||||
udict = msgspec.to_builtins(u, enc_hook=enc_hook)
|
||||
|
@ -134,6 +147,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
|
|||
settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook)
|
||||
return msgspec.convert(settings, Config, dec_hook=dec_hook)
|
||||
|
||||
|
||||
@modifies_config
|
||||
def del_user(conf: Config, name: str) -> Config:
|
||||
"""Delete named user account."""
|
||||
|
|
|
@ -11,6 +11,7 @@ def readconf() -> dict:
|
|||
cf["listen"] = _droppy_listeners(cf)
|
||||
return cf | db
|
||||
|
||||
|
||||
def _droppy_listeners(cf):
|
||||
"""Convert Droppy listeners to our format, for typical cases but not in full."""
|
||||
for listener in cf["listeners"]:
|
||||
|
@ -20,15 +21,19 @@ def _droppy_listeners(cf):
|
|||
continue
|
||||
socket = listener.get("socket")
|
||||
if socket:
|
||||
if isinstance(socket, list): socket = socket[0]
|
||||
if isinstance(socket, list):
|
||||
socket = socket[0]
|
||||
return f"{socket}"
|
||||
port = listener["port"]
|
||||
if isinstance(port, list): port = port[0]
|
||||
if isinstance(port, list):
|
||||
port = port[0]
|
||||
host = listener["host"]
|
||||
if isinstance(host, list): host = host[0]
|
||||
if host in ("127.0.0.1", "::", "localhost"): return f":{port}"
|
||||
if isinstance(host, list):
|
||||
host = host[0]
|
||||
if host in ("127.0.0.1", "::", "localhost"):
|
||||
return f":{port}"
|
||||
return f"{host}:{port}"
|
||||
except (KeyError, IndexError):
|
||||
continue
|
||||
# If none matched, fallback to Droppy default
|
||||
return f"0.0.0.0:8989"
|
||||
return "0.0.0.0:8989"
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import asyncio
|
||||
import os
|
||||
import unicodedata
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from cista import config
|
||||
from cista.util import filename
|
||||
|
@ -13,6 +11,7 @@ def fuid(stat) -> str:
|
|||
"""Unique file ID. Stays the same on renames and modification."""
|
||||
return config.derived_secret("filekey-inode", stat.st_dev, stat.st_ino).hex()
|
||||
|
||||
|
||||
class File:
|
||||
def __init__(self, filename):
|
||||
self.path = config.config.path / filename
|
||||
|
@ -30,23 +29,24 @@ class File:
|
|||
self.writable = True
|
||||
|
||||
def write(self, pos, buffer, *, file_size=None):
|
||||
assert self.fd is not None
|
||||
if not self.writable:
|
||||
# Create/open file
|
||||
self.open_rw()
|
||||
assert self.fd is not None
|
||||
if file_size is not None:
|
||||
os.ftruncate(self.fd, file_size)
|
||||
os.lseek(self.fd, pos, os.SEEK_SET)
|
||||
os.write(self.fd, buffer)
|
||||
|
||||
def __getitem__(self, slice):
|
||||
assert self.fd is not None
|
||||
if self.fd is None:
|
||||
self.open_ro()
|
||||
assert self.fd is not None
|
||||
os.lseek(self.fd, slice.start, os.SEEK_SET)
|
||||
l = slice.stop - slice.start
|
||||
data = os.read(self.fd, l)
|
||||
if len(data) < l: raise EOFError("Error reading requested range")
|
||||
size = slice.stop - slice.start
|
||||
data = os.read(self.fd, size)
|
||||
if len(data) < size:
|
||||
raise EOFError("Error reading requested range")
|
||||
return data
|
||||
|
||||
def close(self):
|
||||
|
@ -59,10 +59,11 @@ class File:
|
|||
|
||||
|
||||
class FileServer:
|
||||
|
||||
async def start(self):
|
||||
self.alink = AsyncLink()
|
||||
self.worker = asyncio.get_event_loop().run_in_executor(None, self.worker_thread, self.alink.to_sync)
|
||||
self.worker = asyncio.get_event_loop().run_in_executor(
|
||||
None, self.worker_thread, self.alink.to_sync
|
||||
)
|
||||
self.cache = LRUCache(File, capacity=10, maxage=5.0)
|
||||
|
||||
async def stop(self):
|
||||
|
@ -91,4 +92,4 @@ class FileServer:
|
|||
def download(self, name, start, end):
|
||||
name = filename.sanitize(name)
|
||||
f = self.cache[name]
|
||||
return f[start: end]
|
||||
return f[start:end]
|
||||
|
|
|
@ -10,6 +10,7 @@ from cista import config, server80
|
|||
def run(dev=False):
|
||||
"""Run Sanic main process that spawns worker processes to serve HTTP requests."""
|
||||
from .app import app
|
||||
|
||||
url, opts = parse_listen(config.config.listen)
|
||||
# Silence Sanic's warning about running in production rather than debug
|
||||
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1"
|
||||
|
@ -21,20 +22,33 @@ def run(dev=False):
|
|||
domain = opts["host"]
|
||||
check_cert(confdir / domain, domain)
|
||||
opts["ssl"] = str(confdir / domain) # type: ignore
|
||||
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, reload_dir={confdir, wwwroot}, access_log=True) # type: ignore
|
||||
app.prepare(
|
||||
**opts,
|
||||
motd=False,
|
||||
dev=dev,
|
||||
auto_reload=dev,
|
||||
reload_dir={confdir, wwwroot},
|
||||
access_log=True,
|
||||
) # type: ignore
|
||||
Sanic.serve()
|
||||
|
||||
|
||||
def check_cert(certdir, domain):
|
||||
if (certdir / "privkey.pem").exist() and (certdir / "fullchain.pem").exists():
|
||||
return
|
||||
# TODO: Use certbot to fetch a cert
|
||||
raise ValueError(f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}")
|
||||
raise ValueError(
|
||||
f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}"
|
||||
)
|
||||
|
||||
|
||||
def parse_listen(listen):
|
||||
if listen.startswith("/"):
|
||||
unix = Path(listen).resolve()
|
||||
if not unix.parent.exists():
|
||||
raise ValueError(f"Directory for unix socket does not exist: {unix.parent}/")
|
||||
raise ValueError(
|
||||
f"Directory for unix socket does not exist: {unix.parent}/"
|
||||
)
|
||||
return "http://localhost", {"unix": unix}
|
||||
elif re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE):
|
||||
return f"https://{listen}", {"host": listen, "port": 443, "ssl": True}
|
||||
|
|
|
@ -2,6 +2,7 @@ from sanic import Sanic, exceptions, response
|
|||
|
||||
app = Sanic("server80")
|
||||
|
||||
|
||||
# Send all HTTP users to HTTPS
|
||||
@app.exception(exceptions.NotFound, exceptions.MethodNotSupported)
|
||||
def redirect_everything_else(request, exception):
|
||||
|
@ -10,6 +11,7 @@ def redirect_everything_else(request, exception):
|
|||
return response.redirect(f"https://{server}{path}", status=308)
|
||||
return response.text("Bad Request. Please use HTTPS!", status=400)
|
||||
|
||||
|
||||
# ACME challenge for LetsEncrypt
|
||||
@app.get("/.well-known/acme-challenge/<challenge>")
|
||||
async def letsencrypt(request, challenge):
|
||||
|
@ -18,4 +20,5 @@ async def letsencrypt(request, challenge):
|
|||
except KeyError:
|
||||
return response.text(f"ACME challenge not found: {challenge}", status=404)
|
||||
|
||||
|
||||
acme_challenges = {}
|
||||
|
|
|
@ -4,16 +4,21 @@ import jwt
|
|||
|
||||
from cista.config import derived_secret
|
||||
|
||||
session_secret = lambda: derived_secret("session")
|
||||
|
||||
def session_secret():
|
||||
return derived_secret("session")
|
||||
|
||||
|
||||
max_age = 365 * 86400 # Seconds since last login
|
||||
|
||||
|
||||
def get(request):
|
||||
try:
|
||||
return jwt.decode(request.cookies.s, session_secret(), algorithms=["HS256"])
|
||||
except Exception as e:
|
||||
s = None
|
||||
except Exception:
|
||||
return False if "s" in request.cookies else None
|
||||
|
||||
|
||||
def create(res, username, **kwargs):
|
||||
data = {
|
||||
"exp": int(time()) + max_age,
|
||||
|
@ -23,12 +28,14 @@ def create(res, username, **kwargs):
|
|||
s = jwt.encode(data, session_secret())
|
||||
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
|
||||
|
||||
|
||||
def update(res, s, **kwargs):
|
||||
s.update(kwargs)
|
||||
s = jwt.encode(s, session_secret())
|
||||
max_age = max(1, s["exp"] - int(time())) # type: ignore
|
||||
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
|
||||
|
||||
|
||||
def delete(res):
|
||||
res.cookies.delete_cookie("s")
|
||||
|
||||
|
|
|
@ -14,10 +14,12 @@ def asend(ws, msg):
|
|||
"""Send JSON message or bytes to a websocket"""
|
||||
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
|
||||
|
||||
|
||||
def jres(data, **kwargs):
|
||||
"""JSON Sanic response, using msgspec encoding"""
|
||||
return raw(msgspec.json.encode(data), content_type="application/json", **kwargs)
|
||||
|
||||
|
||||
async def handle_sanic_exception(request, e):
|
||||
logger.exception(e)
|
||||
context, code = {}, 500
|
||||
|
@ -30,7 +32,9 @@ async def handle_sanic_exception(request, e):
|
|||
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
|
||||
# Non-browsers get JSON errors
|
||||
if "text/html" not in request.headers.accept:
|
||||
return jres(ErrorMsg({"code": code, "message": message, **context}), status=code)
|
||||
return jres(
|
||||
ErrorMsg({"code": code, "message": message, **context}), status=code
|
||||
)
|
||||
# Redirections flash the error message via cookies
|
||||
if "redirect" in context:
|
||||
res = redirect(context["redirect"])
|
||||
|
@ -39,8 +43,10 @@ async def handle_sanic_exception(request, e):
|
|||
# Otherwise use Sanic's default error page
|
||||
return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full()
|
||||
|
||||
|
||||
def websocket_wrapper(handler):
|
||||
"""Decorator for websocket handlers that catches exceptions and sends them back to the client"""
|
||||
|
||||
@wraps(handler)
|
||||
async def wrapper(request, ws, *args, **kwargs):
|
||||
try:
|
||||
|
@ -55,4 +61,5 @@ def websocket_wrapper(handler):
|
|||
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
|
||||
await asend(ws, ErrorMsg({"code": code, "message": message, **context}))
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
|
|
@ -11,6 +11,7 @@ class LRUCache:
|
|||
maxage (float): Max age for items in cache in seconds.
|
||||
cache (list): Internal list storing the cache items.
|
||||
"""
|
||||
|
||||
def __init__(self, open: callable, *, capacity: int, maxage: float):
|
||||
"""
|
||||
Initialize LRUCache.
|
||||
|
|
|
@ -6,6 +6,7 @@ def generate(n=4):
|
|||
wl = list(words)
|
||||
return ".".join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
|
||||
|
||||
|
||||
# A custom list of 1024 common 3-6 letter words, with unique 3-prefixes and no prefix words, entropy 2.1b/letter 10b/word
|
||||
words: list = """
|
||||
able about absent abuse access acid across act adapt add adjust admit adult advice affair afraid again age agree ahead
|
||||
|
|
|
@ -13,11 +13,13 @@ def setup_temp_dir():
|
|||
config.config = config.Config(path=Path(tmpdirname), listen=":0")
|
||||
yield Path(tmpdirname)
|
||||
|
||||
|
||||
def test_mkdir(setup_temp_dir):
|
||||
cmd = MkDir(path="new_folder")
|
||||
cmd()
|
||||
assert (setup_temp_dir / "new_folder").is_dir()
|
||||
|
||||
|
||||
def test_rename(setup_temp_dir):
|
||||
(setup_temp_dir / "old_name").mkdir()
|
||||
cmd = Rename(path="old_name", to="new_name")
|
||||
|
@ -25,12 +27,14 @@ def test_rename(setup_temp_dir):
|
|||
assert not (setup_temp_dir / "old_name").exists()
|
||||
assert (setup_temp_dir / "new_name").is_dir()
|
||||
|
||||
|
||||
def test_rm(setup_temp_dir):
|
||||
(setup_temp_dir / "folder_to_remove").mkdir()
|
||||
cmd = Rm(sel=["folder_to_remove"])
|
||||
cmd()
|
||||
assert not (setup_temp_dir / "folder_to_remove").exists()
|
||||
|
||||
|
||||
def test_mv(setup_temp_dir):
|
||||
(setup_temp_dir / "folder_to_move").mkdir()
|
||||
(setup_temp_dir / "destination").mkdir()
|
||||
|
@ -39,6 +43,7 @@ def test_mv(setup_temp_dir):
|
|||
assert not (setup_temp_dir / "folder_to_move").exists()
|
||||
assert (setup_temp_dir / "destination" / "folder_to_move").is_dir()
|
||||
|
||||
|
||||
def test_cp(setup_temp_dir):
|
||||
(setup_temp_dir / "folder_to_copy").mkdir()
|
||||
(setup_temp_dir / "destination").mkdir()
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from time import sleep
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from cista.util.lrucache import LRUCache # Replace with actual import
|
||||
|
||||
|
||||
|
@ -12,24 +10,28 @@ def mock_open(key):
|
|||
mock.content = f"content-{key}"
|
||||
return mock
|
||||
|
||||
|
||||
def test_contains():
|
||||
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
|
||||
assert "key1" not in cache
|
||||
cache["key1"]
|
||||
assert "key1" in cache
|
||||
|
||||
|
||||
def test_getitem():
|
||||
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
|
||||
assert cache["key1"].content == "content-key1"
|
||||
|
||||
|
||||
def test_capacity():
|
||||
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
|
||||
item1 = cache["key1"]
|
||||
item2 = cache["key2"]
|
||||
cache["key2"]
|
||||
cache["key3"]
|
||||
assert "key1" not in cache
|
||||
item1.close.assert_called()
|
||||
|
||||
|
||||
def test_expiry():
|
||||
cache = LRUCache(open=mock_open, capacity=2, maxage=0.1)
|
||||
item = cache["key1"]
|
||||
|
@ -38,6 +40,7 @@ def test_expiry():
|
|||
assert "key1" not in cache
|
||||
item.close.assert_called()
|
||||
|
||||
|
||||
def test_close():
|
||||
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
|
||||
item = cache["key1"]
|
||||
|
@ -45,6 +48,7 @@ def test_close():
|
|||
assert "key1" not in cache
|
||||
item.close.assert_called()
|
||||
|
||||
|
||||
def test_lru_mechanism():
|
||||
cache = LRUCache(open=mock_open, capacity=2, maxage=10)
|
||||
item1 = cache["key1"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user