Frontend created and rewritten a few times, with some backend fixes #1

Merged
leo merged 110 commits from plaintable into main 2023-11-08 20:38:40 +00:00
16 changed files with 176 additions and 52 deletions
Showing only changes of commit acdd776b92 - Show all commits

View File

@ -1 +1,3 @@
from cista._version import __version__ from cista._version import __version__
__version__ # Public API

View File

@ -34,6 +34,7 @@ User management:
--password Reset password --password Reset password
""" """
def main(): def main():
# Dev mode doesn't catch exceptions # Dev mode doesn't catch exceptions
if "--dev" in sys.argv: if "--dev" in sys.argv:
@ -45,6 +46,7 @@ def main():
print("Error:", e) print("Error:", e)
return 1 return 1
def _main(): def _main():
args = docopt(doc) args = docopt(doc)
if args["--user"]: if args["--user"]:
@ -64,19 +66,25 @@ def _main():
if not necessary_opts: if not necessary_opts:
# Maybe run without arguments # Maybe run without arguments
print(doc) print(doc)
print("No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n") print(
"No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n"
)
return 1 return 1
settings = {} settings = {}
if import_droppy: if import_droppy:
if exists: if exists:
raise ValueError(f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}") raise ValueError(
f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}"
)
settings = droppy.readconf() settings = droppy.readconf()
if path: settings["path"] = path if path:
if listen: settings["listen"] = listen settings["path"] = path
if listen:
settings["listen"] = listen
operation = config.update_config(settings) operation = config.update_config(settings)
print(f"Config {operation}: {config.conffile}") print(f"Config {operation}: {config.conffile}")
# Prepare to serve # Prepare to serve
domain = unix = port = None unix = None
url, _ = serve.parse_listen(config.config.listen) url, _ = serve.parse_listen(config.config.listen)
if not config.config.path.is_dir(): if not config.config.path.is_dir():
raise ValueError(f"No such directory: {config.config.path}") raise ValueError(f"No such directory: {config.config.path}")
@ -88,6 +96,7 @@ def _main():
# Run the server # Run the server
serve.run(dev=dev) serve.run(dev=dev)
def _confdir(args): def _confdir(args):
if args["-c"]: if args["-c"]:
# Custom config directory # Custom config directory
@ -99,6 +108,7 @@ def _confdir(args):
confdir = confdir.parent confdir = confdir.parent
config.conffile = config.conffile.with_parent(confdir) config.conffile = config.conffile.with_parent(confdir)
def _user(args): def _user(args):
_confdir(args) _confdir(args)
config.load_config() config.load_config()
@ -123,5 +133,6 @@ def _user(args):
if res == "read": if res == "read":
print(" No changes") print(" No changes")
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@ -12,15 +12,18 @@ from cista.util.apphelpers import asend, websocket_wrapper
bp = Blueprint("api", url_prefix="/api") bp = Blueprint("api", url_prefix="/api")
fileserver = FileServer() fileserver = FileServer()
@bp.before_server_start @bp.before_server_start
async def start_fileserver(app, _): async def start_fileserver(app, _):
await fileserver.start() await fileserver.start()
@bp.after_server_stop @bp.after_server_stop
async def stop_fileserver(app, _): async def stop_fileserver(app, _):
await fileserver.stop() await fileserver.stop()
@bp.websocket('upload')
@bp.websocket("upload")
@websocket_wrapper @websocket_wrapper
async def upload(req, ws): async def upload(req, ws):
alink = fileserver.alink alink = fileserver.alink
@ -28,7 +31,9 @@ async def upload(req, ws):
req = None req = None
text = await ws.recv() text = await ws.recv()
if not isinstance(text, str): if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") raise ValueError(
f"Expected JSON control, got binary len(data) = {len(text)}"
)
req = msgspec.json.decode(text, type=FileRange) req = msgspec.json.decode(text, type=FileRange)
pos = req.start pos = req.start
data = None data = None
@ -41,9 +46,10 @@ async def upload(req, ws):
# Report success # Report success
res = StatusMsg(status="ack", req=req) res = StatusMsg(status="ack", req=req)
await asend(ws, res) await asend(ws, res)
#await ws.drain() # await ws.drain()
@bp.websocket('download')
@bp.websocket("download")
@websocket_wrapper @websocket_wrapper
async def download(req, ws): async def download(req, ws):
alink = fileserver.alink alink = fileserver.alink
@ -51,18 +57,21 @@ async def download(req, ws):
req = None req = None
text = await ws.recv() text = await ws.recv()
if not isinstance(text, str): if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") raise ValueError(
f"Expected JSON control, got binary len(data) = {len(text)}"
)
req = msgspec.json.decode(text, type=FileRange) req = msgspec.json.decode(text, type=FileRange)
pos = req.start pos = req.start
while pos < req.end: while pos < req.end:
end = min(req.end, pos + (1<<20)) end = min(req.end, pos + (1 << 20))
data = typing.cast(bytes, await alink(("download", req.name, pos, end))) data = typing.cast(bytes, await alink(("download", req.name, pos, end)))
await asend(ws, data) await asend(ws, data)
pos += len(data) pos += len(data)
# Report success # Report success
res = StatusMsg(status="ack", req=req) res = StatusMsg(status="ack", req=req)
await asend(ws, res) await asend(ws, res)
#await ws.drain() # await ws.drain()
@bp.websocket("control") @bp.websocket("control")
@websocket_wrapper @websocket_wrapper
@ -71,6 +80,7 @@ async def control(req, ws):
await asyncio.to_thread(cmd) await asyncio.to_thread(cmd)
await asend(ws, StatusMsg(status="ack", req=cmd)) await asend(ws, StatusMsg(status="ack", req=cmd))
@bp.websocket("watch") @bp.websocket("watch")
@websocket_wrapper @websocket_wrapper
async def watch(req, ws): async def watch(req, ws):

View File

@ -2,7 +2,6 @@ import mimetypes
from importlib.resources import files from importlib.resources import files
from urllib.parse import unquote from urllib.parse import unquote
from html5tagger import E
from sanic import Blueprint, Sanic, raw from sanic import Blueprint, Sanic, raw
from sanic.exceptions import Forbidden, NotFound from sanic.exceptions import Forbidden, NotFound
@ -16,15 +15,18 @@ app.blueprint(auth.bp)
app.blueprint(bp) app.blueprint(bp)
app.exception(Exception)(handle_sanic_exception) app.exception(Exception)(handle_sanic_exception)
@app.before_server_start @app.before_server_start
async def main_start(app, loop): async def main_start(app, loop):
config.load_config() config.load_config()
await watching.start(app, loop) await watching.start(app, loop)
@app.after_server_stop @app.after_server_stop
async def main_stop(app, loop): async def main_stop(app, loop):
await watching.stop(app, loop) await watching.stop(app, loop)
@app.on_request @app.on_request
async def use_session(req): async def use_session(req):
req.ctx.session = session.get(req) req.ctx.session = session.get(req)
@ -41,13 +43,21 @@ async def use_session(req):
if origin and origin.split("//", 1)[1] != req.host: if origin and origin.split("//", 1)[1] != req.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted") raise Forbidden("Invalid origin: Cross-Site requests not permitted")
@app.before_server_start @app.before_server_start
def http_fileserver(app, _): def http_fileserver(app, _):
bp = Blueprint("fileserver") bp = Blueprint("fileserver")
bp.on_request(auth.verify) bp.on_request(auth.verify)
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True) bp.static(
"/files/",
config.config.path,
use_content_range=True,
stream_large_files=True,
directory_view=True,
)
app.blueprint(bp) app.blueprint(bp)
@app.get("/<path:path>", static=True) @app.get("/<path:path>", static=True)
async def wwwroot(req, path=""): async def wwwroot(req, path=""):
"""Frontend files only""" """Frontend files only"""
@ -55,6 +65,8 @@ async def wwwroot(req, path=""):
try: try:
index = files("cista").joinpath("wwwroot", name).read_bytes() index = files("cista").joinpath("wwwroot", name).read_bytes()
except OSError as e: except OSError as e:
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)}) raise NotFound(
f"File not found: /{path}", extra={"name": name, "exception": repr(e)}
)
mime = mimetypes.guess_type(name)[0] or "application/octet-stream" mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
return raw(index, content_type=mime) return raw(index, content_type=mime)

View File

@ -12,10 +12,12 @@ from sanic.exceptions import BadRequest, Forbidden, Unauthorized
from cista import config, session from cista import config, session
_argon = argon2.PasswordHasher() _argon = argon2.PasswordHasher()
_droppyhash = re.compile(r'^([a-f0-9]{64})\$([a-f0-9]{8})$') _droppyhash = re.compile(r"^([a-f0-9]{64})\$([a-f0-9]{8})$")
def _pwnorm(password): def _pwnorm(password):
return normalize('NFC', password).strip().encode() return normalize("NFC", password).strip().encode()
def login(username: str, password: str): def login(username: str, password: str):
un = _pwnorm(username) un = _pwnorm(username)
@ -49,33 +51,51 @@ def login(username: str, password: str):
u.lastSeen = now u.lastSeen = now
return u return u
def set_password(user: config.User, password: str): def set_password(user: config.User, password: str):
user.hash = _argon.hash(_pwnorm(password)) user.hash = _argon.hash(_pwnorm(password))
class LoginResponse(msgspec.Struct): class LoginResponse(msgspec.Struct):
user: str = "" user: str = ""
privileged: bool = False privileged: bool = False
error: str = "" error: str = ""
def verify(request, privileged=False): def verify(request, privileged=False):
"""Raise Unauthorized or Forbidden if the request is not authorized""" """Raise Unauthorized or Forbidden if the request is not authorized"""
if privileged: if privileged:
if request.ctx.user: if request.ctx.user:
if request.ctx.user.privileged: return if request.ctx.user.privileged:
return
raise Forbidden("Access Forbidden: Only for privileged users") raise Forbidden("Access Forbidden: Only for privileged users")
elif config.config.public or request.ctx.user: return elif config.config.public or request.ctx.user:
return
raise Unauthorized("Login required", "cookie", context={"redirect": "/login"}) raise Unauthorized("Login required", "cookie", context={"redirect": "/login"})
bp = Blueprint("auth") bp = Blueprint("auth")
@bp.get("/login") @bp.get("/login")
async def login_page(request): async def login_page(request):
doc = Document("Cista Login") doc = Document("Cista Login")
with doc.div(id="login"): with doc.div(id="login"):
with doc.form(method="POST", autocomplete="on"): with doc.form(method="POST", autocomplete="on"):
doc.h1("Login") doc.h1("Login")
doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br doc.input(
doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br name="username",
placeholder="Username",
autocomplete="username",
required=True,
).br
doc.input(
type="password",
name="password",
placeholder="Password",
autocomplete="current-password",
required=True,
).br
doc.input(type="submit", value="Login") doc.input(type="submit", value="Login")
s = session.get(request) s = session.get(request)
if s: if s:
@ -84,7 +104,12 @@ async def login_page(request):
doc.input(type="submit", value=f"Logout {name}") doc.input(type="submit", value=f"Logout {name}")
flash = request.cookies.message flash = request.cookies.message
if flash: if flash:
doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8") doc.dialog(
flash,
id="flash",
open=True,
style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8",
)
res = html(doc) res = html(doc)
if flash: if flash:
res.cookies.delete_cookie("flash") res.cookies.delete_cookie("flash")
@ -92,6 +117,7 @@ async def login_page(request):
session.delete(res) session.delete(res)
return res return res
@bp.post("/login") @bp.post("/login")
async def login_post(request): async def login_post(request):
try: try:
@ -118,6 +144,7 @@ async def login_post(request):
session.create(res, username) session.create(res, username)
return res return res
@bp.post("/logout") @bp.post("/logout")
async def logout_post(request): async def logout_post(request):
s = request.ctx.session s = request.ctx.session

View File

@ -17,26 +17,28 @@ class Config(msgspec.Struct):
users: dict[str, User] = {} users: dict[str, User] = {}
links: dict[str, Link] = {} links: dict[str, Link] = {}
class User(msgspec.Struct, omit_defaults=True): class User(msgspec.Struct, omit_defaults=True):
privileged: bool = False privileged: bool = False
hash: str = "" hash: str = ""
lastSeen: int = 0 lastSeen: int = 0
class Link(msgspec.Struct, omit_defaults=True): class Link(msgspec.Struct, omit_defaults=True):
location: str location: str
creator: str = "" creator: str = ""
expires: int = 0 expires: int = 0
config = None config = None
conffile = Path.home() / ".local/share/cista/db.toml" conffile = Path.home() / ".local/share/cista/db.toml"
def derived_secret(*params, len=8) -> bytes: def derived_secret(*params, len=8) -> bytes:
"""Used to derive secret keys from the main secret""" """Used to derive secret keys from the main secret"""
# Each part is made the same length by hashing first # Each part is made the same length by hashing first
combined = b"".join( combined = b"".join(
sha256( sha256(p if isinstance(p, bytes) else f"{p}".encode()).digest()
p if isinstance(p, bytes) else f"{p}".encode()
).digest()
for p in [config.secret, *params] for p in [config.secret, *params]
) )
# Output a bytes of the desired length # Output a bytes of the desired length
@ -48,11 +50,13 @@ def enc_hook(obj):
return obj.as_posix() return obj.as_posix()
raise TypeError raise TypeError
def dec_hook(typ, obj): def dec_hook(typ, obj):
if typ is Path: if typ is Path:
return Path(obj) return Path(obj)
raise TypeError raise TypeError
def config_update(modify): def config_update(modify):
global config global config
if not conffile.exists(): if not conffile.exists():
@ -93,21 +97,28 @@ def config_update(modify):
config = c config = c
return "modified" if old else "created" return "modified" if old else "created"
def modifies_config(modify): def modifies_config(modify):
"""Decorator for functions that modify the config file""" """Decorator for functions that modify the config file"""
@wraps(modify) @wraps(modify)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
m = lambda c: modify(c, *args, **kwargs) def m(c):
return modify(c, *args, **kwargs)
# Retry modification in case of write collision # Retry modification in case of write collision
while (c := config_update(m)) == "collision": while (c := config_update(m)) == "collision":
time.sleep(0.01) time.sleep(0.01)
return c return c
return wrapper return wrapper
def load_config(): def load_config():
global config global config
config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook) config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook)
@modifies_config @modifies_config
def update_config(conf: Config, changes: dict) -> Config: def update_config(conf: Config, changes: dict) -> Config:
"""Create/update the config with new values, respecting changes done by others.""" """Create/update the config with new values, respecting changes done by others."""
@ -116,6 +127,7 @@ def update_config(conf: Config, changes: dict) -> Config:
settings.update(changes) settings.update(changes)
return msgspec.convert(settings, Config, dec_hook=dec_hook) return msgspec.convert(settings, Config, dec_hook=dec_hook)
@modifies_config @modifies_config
def update_user(conf: Config, name: str, changes: dict) -> Config: def update_user(conf: Config, name: str, changes: dict) -> Config:
"""Create/update a user with new values, respecting changes done by others.""" """Create/update a user with new values, respecting changes done by others."""
@ -126,6 +138,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
u = User() u = User()
if "password" in changes: if "password" in changes:
from . import auth from . import auth
auth.set_password(u, changes["password"]) auth.set_password(u, changes["password"])
del changes["password"] del changes["password"]
udict = msgspec.to_builtins(u, enc_hook=enc_hook) udict = msgspec.to_builtins(u, enc_hook=enc_hook)
@ -134,6 +147,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config:
settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook) settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook)
return msgspec.convert(settings, Config, dec_hook=dec_hook) return msgspec.convert(settings, Config, dec_hook=dec_hook)
@modifies_config @modifies_config
def del_user(conf: Config, name: str) -> Config: def del_user(conf: Config, name: str) -> Config:
"""Delete named user account.""" """Delete named user account."""

View File

@ -11,6 +11,7 @@ def readconf() -> dict:
cf["listen"] = _droppy_listeners(cf) cf["listen"] = _droppy_listeners(cf)
return cf | db return cf | db
def _droppy_listeners(cf): def _droppy_listeners(cf):
"""Convert Droppy listeners to our format, for typical cases but not in full.""" """Convert Droppy listeners to our format, for typical cases but not in full."""
for listener in cf["listeners"]: for listener in cf["listeners"]:
@ -20,15 +21,19 @@ def _droppy_listeners(cf):
continue continue
socket = listener.get("socket") socket = listener.get("socket")
if socket: if socket:
if isinstance(socket, list): socket = socket[0] if isinstance(socket, list):
socket = socket[0]
return f"{socket}" return f"{socket}"
port = listener["port"] port = listener["port"]
if isinstance(port, list): port = port[0] if isinstance(port, list):
port = port[0]
host = listener["host"] host = listener["host"]
if isinstance(host, list): host = host[0] if isinstance(host, list):
if host in ("127.0.0.1", "::", "localhost"): return f":{port}" host = host[0]
if host in ("127.0.0.1", "::", "localhost"):
return f":{port}"
return f"{host}:{port}" return f"{host}:{port}"
except (KeyError, IndexError): except (KeyError, IndexError):
continue continue
# If none matched, fallback to Droppy default # If none matched, fallback to Droppy default
return f"0.0.0.0:8989" return "0.0.0.0:8989"

View File

@ -1,7 +1,5 @@
import asyncio import asyncio
import os import os
import unicodedata
from pathlib import PurePosixPath
from cista import config from cista import config
from cista.util import filename from cista.util import filename
@ -13,6 +11,7 @@ def fuid(stat) -> str:
"""Unique file ID. Stays the same on renames and modification.""" """Unique file ID. Stays the same on renames and modification."""
return config.derived_secret("filekey-inode", stat.st_dev, stat.st_ino).hex() return config.derived_secret("filekey-inode", stat.st_dev, stat.st_ino).hex()
class File: class File:
def __init__(self, filename): def __init__(self, filename):
self.path = config.config.path / filename self.path = config.config.path / filename
@ -30,23 +29,24 @@ class File:
self.writable = True self.writable = True
def write(self, pos, buffer, *, file_size=None): def write(self, pos, buffer, *, file_size=None):
assert self.fd is not None
if not self.writable: if not self.writable:
# Create/open file # Create/open file
self.open_rw() self.open_rw()
assert self.fd is not None
if file_size is not None: if file_size is not None:
os.ftruncate(self.fd, file_size) os.ftruncate(self.fd, file_size)
os.lseek(self.fd, pos, os.SEEK_SET) os.lseek(self.fd, pos, os.SEEK_SET)
os.write(self.fd, buffer) os.write(self.fd, buffer)
def __getitem__(self, slice): def __getitem__(self, slice):
assert self.fd is not None
if self.fd is None: if self.fd is None:
self.open_ro() self.open_ro()
assert self.fd is not None
os.lseek(self.fd, slice.start, os.SEEK_SET) os.lseek(self.fd, slice.start, os.SEEK_SET)
l = slice.stop - slice.start size = slice.stop - slice.start
data = os.read(self.fd, l) data = os.read(self.fd, size)
if len(data) < l: raise EOFError("Error reading requested range") if len(data) < size:
raise EOFError("Error reading requested range")
return data return data
def close(self): def close(self):
@ -59,10 +59,11 @@ class File:
class FileServer: class FileServer:
async def start(self): async def start(self):
self.alink = AsyncLink() self.alink = AsyncLink()
self.worker = asyncio.get_event_loop().run_in_executor(None, self.worker_thread, self.alink.to_sync) self.worker = asyncio.get_event_loop().run_in_executor(
None, self.worker_thread, self.alink.to_sync
)
self.cache = LRUCache(File, capacity=10, maxage=5.0) self.cache = LRUCache(File, capacity=10, maxage=5.0)
async def stop(self): async def stop(self):
@ -91,4 +92,4 @@ class FileServer:
def download(self, name, start, end): def download(self, name, start, end):
name = filename.sanitize(name) name = filename.sanitize(name)
f = self.cache[name] f = self.cache[name]
return f[start: end] return f[start:end]

View File

@ -10,6 +10,7 @@ from cista import config, server80
def run(dev=False): def run(dev=False):
"""Run Sanic main process that spawns worker processes to serve HTTP requests.""" """Run Sanic main process that spawns worker processes to serve HTTP requests."""
from .app import app from .app import app
url, opts = parse_listen(config.config.listen) url, opts = parse_listen(config.config.listen)
# Silence Sanic's warning about running in production rather than debug # Silence Sanic's warning about running in production rather than debug
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1" os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1"
@ -21,20 +22,33 @@ def run(dev=False):
domain = opts["host"] domain = opts["host"]
check_cert(confdir / domain, domain) check_cert(confdir / domain, domain)
opts["ssl"] = str(confdir / domain) # type: ignore opts["ssl"] = str(confdir / domain) # type: ignore
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, reload_dir={confdir, wwwroot}, access_log=True) # type: ignore app.prepare(
**opts,
motd=False,
dev=dev,
auto_reload=dev,
reload_dir={confdir, wwwroot},
access_log=True,
) # type: ignore
Sanic.serve() Sanic.serve()
def check_cert(certdir, domain): def check_cert(certdir, domain):
if (certdir / "privkey.pem").exist() and (certdir / "fullchain.pem").exists(): if (certdir / "privkey.pem").exist() and (certdir / "fullchain.pem").exists():
return return
# TODO: Use certbot to fetch a cert # TODO: Use certbot to fetch a cert
raise ValueError(f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}") raise ValueError(
f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}"
)
def parse_listen(listen): def parse_listen(listen):
if listen.startswith("/"): if listen.startswith("/"):
unix = Path(listen).resolve() unix = Path(listen).resolve()
if not unix.parent.exists(): if not unix.parent.exists():
raise ValueError(f"Directory for unix socket does not exist: {unix.parent}/") raise ValueError(
f"Directory for unix socket does not exist: {unix.parent}/"
)
return "http://localhost", {"unix": unix} return "http://localhost", {"unix": unix}
elif re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE): elif re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE):
return f"https://{listen}", {"host": listen, "port": 443, "ssl": True} return f"https://{listen}", {"host": listen, "port": 443, "ssl": True}

View File

@ -2,6 +2,7 @@ from sanic import Sanic, exceptions, response
app = Sanic("server80") app = Sanic("server80")
# Send all HTTP users to HTTPS # Send all HTTP users to HTTPS
@app.exception(exceptions.NotFound, exceptions.MethodNotSupported) @app.exception(exceptions.NotFound, exceptions.MethodNotSupported)
def redirect_everything_else(request, exception): def redirect_everything_else(request, exception):
@ -10,6 +11,7 @@ def redirect_everything_else(request, exception):
return response.redirect(f"https://{server}{path}", status=308) return response.redirect(f"https://{server}{path}", status=308)
return response.text("Bad Request. Please use HTTPS!", status=400) return response.text("Bad Request. Please use HTTPS!", status=400)
# ACME challenge for LetsEncrypt # ACME challenge for LetsEncrypt
@app.get("/.well-known/acme-challenge/<challenge>") @app.get("/.well-known/acme-challenge/<challenge>")
async def letsencrypt(request, challenge): async def letsencrypt(request, challenge):
@ -18,4 +20,5 @@ async def letsencrypt(request, challenge):
except KeyError: except KeyError:
return response.text(f"ACME challenge not found: {challenge}", status=404) return response.text(f"ACME challenge not found: {challenge}", status=404)
acme_challenges = {} acme_challenges = {}

View File

@ -4,16 +4,21 @@ import jwt
from cista.config import derived_secret from cista.config import derived_secret
session_secret = lambda: derived_secret("session")
def session_secret():
return derived_secret("session")
max_age = 365 * 86400 # Seconds since last login max_age = 365 * 86400 # Seconds since last login
def get(request): def get(request):
try: try:
return jwt.decode(request.cookies.s, session_secret(), algorithms=["HS256"]) return jwt.decode(request.cookies.s, session_secret(), algorithms=["HS256"])
except Exception as e: except Exception:
s = None
return False if "s" in request.cookies else None return False if "s" in request.cookies else None
def create(res, username, **kwargs): def create(res, username, **kwargs):
data = { data = {
"exp": int(time()) + max_age, "exp": int(time()) + max_age,
@ -23,12 +28,14 @@ def create(res, username, **kwargs):
s = jwt.encode(data, session_secret()) s = jwt.encode(data, session_secret())
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age) res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
def update(res, s, **kwargs): def update(res, s, **kwargs):
s.update(kwargs) s.update(kwargs)
s = jwt.encode(s, session_secret()) s = jwt.encode(s, session_secret())
max_age = max(1, s["exp"] - int(time())) # type: ignore max_age = max(1, s["exp"] - int(time())) # type: ignore
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age) res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
def delete(res): def delete(res):
res.cookies.delete_cookie("s") res.cookies.delete_cookie("s")

View File

@ -14,10 +14,12 @@ def asend(ws, msg):
"""Send JSON message or bytes to a websocket""" """Send JSON message or bytes to a websocket"""
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode()) return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
def jres(data, **kwargs): def jres(data, **kwargs):
"""JSON Sanic response, using msgspec encoding""" """JSON Sanic response, using msgspec encoding"""
return raw(msgspec.json.encode(data), content_type="application/json", **kwargs) return raw(msgspec.json.encode(data), content_type="application/json", **kwargs)
async def handle_sanic_exception(request, e): async def handle_sanic_exception(request, e):
logger.exception(e) logger.exception(e)
context, code = {}, 500 context, code = {}, 500
@ -30,7 +32,9 @@ async def handle_sanic_exception(request, e):
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}" message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
# Non-browsers get JSON errors # Non-browsers get JSON errors
if "text/html" not in request.headers.accept: if "text/html" not in request.headers.accept:
return jres(ErrorMsg({"code": code, "message": message, **context}), status=code) return jres(
ErrorMsg({"code": code, "message": message, **context}), status=code
)
# Redirections flash the error message via cookies # Redirections flash the error message via cookies
if "redirect" in context: if "redirect" in context:
res = redirect(context["redirect"]) res = redirect(context["redirect"])
@ -39,8 +43,10 @@ async def handle_sanic_exception(request, e):
# Otherwise use Sanic's default error page # Otherwise use Sanic's default error page
return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full() return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full()
def websocket_wrapper(handler): def websocket_wrapper(handler):
"""Decorator for websocket handlers that catches exceptions and sends them back to the client""" """Decorator for websocket handlers that catches exceptions and sends them back to the client"""
@wraps(handler) @wraps(handler)
async def wrapper(request, ws, *args, **kwargs): async def wrapper(request, ws, *args, **kwargs):
try: try:
@ -55,4 +61,5 @@ def websocket_wrapper(handler):
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}" message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
await asend(ws, ErrorMsg({"code": code, "message": message, **context})) await asend(ws, ErrorMsg({"code": code, "message": message, **context}))
raise raise
return wrapper return wrapper

View File

@ -11,6 +11,7 @@ class LRUCache:
maxage (float): Max age for items in cache in seconds. maxage (float): Max age for items in cache in seconds.
cache (list): Internal list storing the cache items. cache (list): Internal list storing the cache items.
""" """
def __init__(self, open: callable, *, capacity: int, maxage: float): def __init__(self, open: callable, *, capacity: int, maxage: float):
""" """
Initialize LRUCache. Initialize LRUCache.

View File

@ -6,6 +6,7 @@ def generate(n=4):
wl = list(words) wl = list(words)
return ".".join(wl.pop(secrets.randbelow(len(wl))) for i in range(n)) return ".".join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
# A custom list of 1024 common 3-6 letter words, with unique 3-prefixes and no prefix words, entropy 2.1b/letter 10b/word # A custom list of 1024 common 3-6 letter words, with unique 3-prefixes and no prefix words, entropy 2.1b/letter 10b/word
words: list = """ words: list = """
able about absent abuse access acid across act adapt add adjust admit adult advice affair afraid again age agree ahead able about absent abuse access acid across act adapt add adjust admit adult advice affair afraid again age agree ahead

View File

@ -13,11 +13,13 @@ def setup_temp_dir():
config.config = config.Config(path=Path(tmpdirname), listen=":0") config.config = config.Config(path=Path(tmpdirname), listen=":0")
yield Path(tmpdirname) yield Path(tmpdirname)
def test_mkdir(setup_temp_dir): def test_mkdir(setup_temp_dir):
cmd = MkDir(path="new_folder") cmd = MkDir(path="new_folder")
cmd() cmd()
assert (setup_temp_dir / "new_folder").is_dir() assert (setup_temp_dir / "new_folder").is_dir()
def test_rename(setup_temp_dir): def test_rename(setup_temp_dir):
(setup_temp_dir / "old_name").mkdir() (setup_temp_dir / "old_name").mkdir()
cmd = Rename(path="old_name", to="new_name") cmd = Rename(path="old_name", to="new_name")
@ -25,12 +27,14 @@ def test_rename(setup_temp_dir):
assert not (setup_temp_dir / "old_name").exists() assert not (setup_temp_dir / "old_name").exists()
assert (setup_temp_dir / "new_name").is_dir() assert (setup_temp_dir / "new_name").is_dir()
def test_rm(setup_temp_dir): def test_rm(setup_temp_dir):
(setup_temp_dir / "folder_to_remove").mkdir() (setup_temp_dir / "folder_to_remove").mkdir()
cmd = Rm(sel=["folder_to_remove"]) cmd = Rm(sel=["folder_to_remove"])
cmd() cmd()
assert not (setup_temp_dir / "folder_to_remove").exists() assert not (setup_temp_dir / "folder_to_remove").exists()
def test_mv(setup_temp_dir): def test_mv(setup_temp_dir):
(setup_temp_dir / "folder_to_move").mkdir() (setup_temp_dir / "folder_to_move").mkdir()
(setup_temp_dir / "destination").mkdir() (setup_temp_dir / "destination").mkdir()
@ -39,6 +43,7 @@ def test_mv(setup_temp_dir):
assert not (setup_temp_dir / "folder_to_move").exists() assert not (setup_temp_dir / "folder_to_move").exists()
assert (setup_temp_dir / "destination" / "folder_to_move").is_dir() assert (setup_temp_dir / "destination" / "folder_to_move").is_dir()
def test_cp(setup_temp_dir): def test_cp(setup_temp_dir):
(setup_temp_dir / "folder_to_copy").mkdir() (setup_temp_dir / "folder_to_copy").mkdir()
(setup_temp_dir / "destination").mkdir() (setup_temp_dir / "destination").mkdir()

View File

@ -1,8 +1,6 @@
from time import sleep from time import sleep
from unittest.mock import Mock from unittest.mock import Mock
import pytest
from cista.util.lrucache import LRUCache # Replace with actual import from cista.util.lrucache import LRUCache # Replace with actual import
@ -12,24 +10,28 @@ def mock_open(key):
mock.content = f"content-{key}" mock.content = f"content-{key}"
return mock return mock
def test_contains(): def test_contains():
cache = LRUCache(open=mock_open, capacity=2, maxage=10) cache = LRUCache(open=mock_open, capacity=2, maxage=10)
assert "key1" not in cache assert "key1" not in cache
cache["key1"] cache["key1"]
assert "key1" in cache assert "key1" in cache
def test_getitem(): def test_getitem():
cache = LRUCache(open=mock_open, capacity=2, maxage=10) cache = LRUCache(open=mock_open, capacity=2, maxage=10)
assert cache["key1"].content == "content-key1" assert cache["key1"].content == "content-key1"
def test_capacity(): def test_capacity():
cache = LRUCache(open=mock_open, capacity=2, maxage=10) cache = LRUCache(open=mock_open, capacity=2, maxage=10)
item1 = cache["key1"] item1 = cache["key1"]
item2 = cache["key2"] cache["key2"]
cache["key3"] cache["key3"]
assert "key1" not in cache assert "key1" not in cache
item1.close.assert_called() item1.close.assert_called()
def test_expiry(): def test_expiry():
cache = LRUCache(open=mock_open, capacity=2, maxage=0.1) cache = LRUCache(open=mock_open, capacity=2, maxage=0.1)
item = cache["key1"] item = cache["key1"]
@ -38,6 +40,7 @@ def test_expiry():
assert "key1" not in cache assert "key1" not in cache
item.close.assert_called() item.close.assert_called()
def test_close(): def test_close():
cache = LRUCache(open=mock_open, capacity=2, maxage=10) cache = LRUCache(open=mock_open, capacity=2, maxage=10)
item = cache["key1"] item = cache["key1"]
@ -45,6 +48,7 @@ def test_close():
assert "key1" not in cache assert "key1" not in cache
item.close.assert_called() item.close.assert_called()
def test_lru_mechanism(): def test_lru_mechanism():
cache = LRUCache(open=mock_open, capacity=2, maxage=10) cache = LRUCache(open=mock_open, capacity=2, maxage=10)
item1 = cache["key1"] item1 = cache["key1"]