diff --git a/cista/__init__.py b/cista/__init__.py index 4861e11..662ce6a 100755 --- a/cista/__init__.py +++ b/cista/__init__.py @@ -1 +1,3 @@ from cista._version import __version__ + +__version__ # Public API diff --git a/cista/__main__.py b/cista/__main__.py index bfff876..bbddc8d 100755 --- a/cista/__main__.py +++ b/cista/__main__.py @@ -7,7 +7,7 @@ import cista from cista import app, config, droppy, serve, server80 from cista.util import pwgen -del app, server80.app # Only import needed, for Sanic multiprocessing +del app, server80.app # Only import needed, for Sanic multiprocessing doc = f"""Cista {cista.__version__} - A file storage for the web. @@ -34,6 +34,7 @@ User management: --password Reset password """ + def main(): # Dev mode doesn't catch exceptions if "--dev" in sys.argv: @@ -45,6 +46,7 @@ def main(): print("Error:", e) return 1 + def _main(): args = docopt(doc) if args["--user"]: @@ -64,19 +66,25 @@ def _main(): if not necessary_opts: # Maybe run without arguments print(doc) - print("No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n") + print( + "No config file found! Get started with:\n cista -l :8000 /path/to/files, or\n cista -l example.com --import-droppy # Uses Droppy files\n" + ) return 1 settings = {} if import_droppy: if exists: - raise ValueError(f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}") + raise ValueError( + f"Importing Droppy: First remove the existing configuration:\n rm {config.conffile}" + ) settings = droppy.readconf() - if path: settings["path"] = path - if listen: settings["listen"] = listen + if path: + settings["path"] = path + if listen: + settings["listen"] = listen operation = config.update_config(settings) print(f"Config {operation}: {config.conffile}") # Prepare to serve - domain = unix = port = None + unix = None url, _ = serve.parse_listen(config.config.listen) if not config.config.path.is_dir(): raise ValueError(f"No such directory: {config.config.path}") @@ -88,6 +96,7 @@ def _main(): # Run the server serve.run(dev=dev) + def _confdir(args): if args["-c"]: # Custom config directory @@ -99,6 +108,7 @@ def _confdir(args): confdir = confdir.parent config.conffile = config.conffile.with_parent(confdir) + def _user(args): _confdir(args) config.load_config() @@ -123,5 +133,6 @@ def _user(args): if res == "read": print(" No changes") + if __name__ == "__main__": sys.exit(main()) diff --git a/cista/api.py b/cista/api.py index 66a3264..39a4a18 100644 --- a/cista/api.py +++ b/cista/api.py @@ -12,15 +12,18 @@ from cista.util.apphelpers import asend, websocket_wrapper bp = Blueprint("api", url_prefix="/api") fileserver = FileServer() + @bp.before_server_start async def start_fileserver(app, _): await fileserver.start() + @bp.after_server_stop async def stop_fileserver(app, _): await fileserver.stop() -@bp.websocket('upload') + +@bp.websocket("upload") @websocket_wrapper async def upload(req, ws): alink = fileserver.alink @@ -28,7 +31,9 @@ async def upload(req, ws): req = None text = await ws.recv() if not isinstance(text, str): - raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") + raise ValueError( + f"Expected JSON control, got binary len(data) = {len(text)}" + ) req = msgspec.json.decode(text, type=FileRange) pos = req.start data = None @@ -41,9 +46,10 @@ async def upload(req, ws): # Report success res = StatusMsg(status="ack", req=req) await asend(ws, res) - #await ws.drain() + # await ws.drain() -@bp.websocket('download') + +@bp.websocket("download") @websocket_wrapper async def download(req, ws): alink = fileserver.alink @@ -51,18 +57,21 @@ async def download(req, ws): req = None text = await ws.recv() if not isinstance(text, str): - raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}") + raise ValueError( + f"Expected JSON control, got binary len(data) = {len(text)}" + ) req = msgspec.json.decode(text, type=FileRange) pos = req.start while pos < req.end: - end = min(req.end, pos + (1<<20)) + end = min(req.end, pos + (1 << 20)) data = typing.cast(bytes, await alink(("download", req.name, pos, end))) await asend(ws, data) pos += len(data) # Report success res = StatusMsg(status="ack", req=req) await asend(ws, res) - #await ws.drain() + # await ws.drain() + @bp.websocket("control") @websocket_wrapper @@ -71,6 +80,7 @@ async def control(req, ws): await asyncio.to_thread(cmd) await asend(ws, StatusMsg(status="ack", req=cmd)) + @bp.websocket("watch") @websocket_wrapper async def watch(req, ws): diff --git a/cista/app.py b/cista/app.py index 8c740db..6bc09e6 100755 --- a/cista/app.py +++ b/cista/app.py @@ -2,7 +2,6 @@ import mimetypes from importlib.resources import files from urllib.parse import unquote -from html5tagger import E from sanic import Blueprint, Sanic, raw from sanic.exceptions import Forbidden, NotFound @@ -16,15 +15,18 @@ app.blueprint(auth.bp) app.blueprint(bp) app.exception(Exception)(handle_sanic_exception) + @app.before_server_start async def main_start(app, loop): config.load_config() await watching.start(app, loop) + @app.after_server_stop async def main_stop(app, loop): await watching.stop(app, loop) + @app.on_request async def use_session(req): req.ctx.session = session.get(req) @@ -41,13 +43,21 @@ async def use_session(req): if origin and origin.split("//", 1)[1] != req.host: raise Forbidden("Invalid origin: Cross-Site requests not permitted") + @app.before_server_start def http_fileserver(app, _): bp = Blueprint("fileserver") bp.on_request(auth.verify) - bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True) + bp.static( + "/files/", + config.config.path, + use_content_range=True, + stream_large_files=True, + directory_view=True, + ) app.blueprint(bp) + @app.get("/", static=True) async def wwwroot(req, path=""): """Frontend files only""" @@ -55,6 +65,8 @@ async def wwwroot(req, path=""): try: index = files("cista").joinpath("wwwroot", name).read_bytes() except OSError as e: - raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)}) + raise NotFound( + f"File not found: /{path}", extra={"name": name, "exception": repr(e)} + ) mime = mimetypes.guess_type(name)[0] or "application/octet-stream" return raw(index, content_type=mime) diff --git a/cista/auth.py b/cista/auth.py index 570f373..e1974e3 100755 --- a/cista/auth.py +++ b/cista/auth.py @@ -12,10 +12,12 @@ from sanic.exceptions import BadRequest, Forbidden, Unauthorized from cista import config, session _argon = argon2.PasswordHasher() -_droppyhash = re.compile(r'^([a-f0-9]{64})\$([a-f0-9]{8})$') +_droppyhash = re.compile(r"^([a-f0-9]{64})\$([a-f0-9]{8})$") + def _pwnorm(password): - return normalize('NFC', password).strip().encode() + return normalize("NFC", password).strip().encode() + def login(username: str, password: str): un = _pwnorm(username) @@ -49,33 +51,51 @@ def login(username: str, password: str): u.lastSeen = now return u + def set_password(user: config.User, password: str): user.hash = _argon.hash(_pwnorm(password)) + class LoginResponse(msgspec.Struct): user: str = "" privileged: bool = False error: str = "" + def verify(request, privileged=False): """Raise Unauthorized or Forbidden if the request is not authorized""" if privileged: if request.ctx.user: - if request.ctx.user.privileged: return + if request.ctx.user.privileged: + return raise Forbidden("Access Forbidden: Only for privileged users") - elif config.config.public or request.ctx.user: return + elif config.config.public or request.ctx.user: + return raise Unauthorized("Login required", "cookie", context={"redirect": "/login"}) + bp = Blueprint("auth") + @bp.get("/login") async def login_page(request): doc = Document("Cista Login") with doc.div(id="login"): with doc.form(method="POST", autocomplete="on"): doc.h1("Login") - doc.input(name="username", placeholder="Username", autocomplete="username", required=True).br - doc.input(type="password", name="password", placeholder="Password", autocomplete="current-password", required=True).br + doc.input( + name="username", + placeholder="Username", + autocomplete="username", + required=True, + ).br + doc.input( + type="password", + name="password", + placeholder="Password", + autocomplete="current-password", + required=True, + ).br doc.input(type="submit", value="Login") s = session.get(request) if s: @@ -84,7 +104,12 @@ async def login_page(request): doc.input(type="submit", value=f"Logout {name}") flash = request.cookies.message if flash: - doc.dialog(flash, id="flash", open=True, style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8") + doc.dialog( + flash, + id="flash", + open=True, + style="position: fixed; top: 0; left: 0; width: 100%; opacity: .8", + ) res = html(doc) if flash: res.cookies.delete_cookie("flash") @@ -92,6 +117,7 @@ async def login_page(request): session.delete(res) return res + @bp.post("/login") async def login_post(request): try: @@ -118,6 +144,7 @@ async def login_post(request): session.create(res, username) return res + @bp.post("/logout") async def logout_post(request): s = request.ctx.session diff --git a/cista/config.py b/cista/config.py index 1b07a8c..fa6fc9f 100755 --- a/cista/config.py +++ b/cista/config.py @@ -17,26 +17,28 @@ class Config(msgspec.Struct): users: dict[str, User] = {} links: dict[str, Link] = {} + class User(msgspec.Struct, omit_defaults=True): privileged: bool = False hash: str = "" lastSeen: int = 0 + class Link(msgspec.Struct, omit_defaults=True): location: str creator: str = "" expires: int = 0 + config = None conffile = Path.home() / ".local/share/cista/db.toml" + def derived_secret(*params, len=8) -> bytes: """Used to derive secret keys from the main secret""" # Each part is made the same length by hashing first combined = b"".join( - sha256( - p if isinstance(p, bytes) else f"{p}".encode() - ).digest() + sha256(p if isinstance(p, bytes) else f"{p}".encode()).digest() for p in [config.secret, *params] ) # Output a bytes of the desired length @@ -48,11 +50,13 @@ def enc_hook(obj): return obj.as_posix() raise TypeError + def dec_hook(typ, obj): if typ is Path: return Path(obj) raise TypeError + def config_update(modify): global config if not conffile.exists(): @@ -93,21 +97,28 @@ def config_update(modify): config = c return "modified" if old else "created" + def modifies_config(modify): """Decorator for functions that modify the config file""" + @wraps(modify) def wrapper(*args, **kwargs): - m = lambda c: modify(c, *args, **kwargs) + def m(c): + return modify(c, *args, **kwargs) + # Retry modification in case of write collision while (c := config_update(m)) == "collision": time.sleep(0.01) return c + return wrapper + def load_config(): global config config = msgspec.toml.decode(conffile.read_bytes(), type=Config, dec_hook=dec_hook) + @modifies_config def update_config(conf: Config, changes: dict) -> Config: """Create/update the config with new values, respecting changes done by others.""" @@ -116,6 +127,7 @@ def update_config(conf: Config, changes: dict) -> Config: settings.update(changes) return msgspec.convert(settings, Config, dec_hook=dec_hook) + @modifies_config def update_user(conf: Config, name: str, changes: dict) -> Config: """Create/update a user with new values, respecting changes done by others.""" @@ -126,6 +138,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: u = User() if "password" in changes: from . import auth + auth.set_password(u, changes["password"]) del changes["password"] udict = msgspec.to_builtins(u, enc_hook=enc_hook) @@ -134,6 +147,7 @@ def update_user(conf: Config, name: str, changes: dict) -> Config: settings["users"][name] = msgspec.convert(udict, User, dec_hook=dec_hook) return msgspec.convert(settings, Config, dec_hook=dec_hook) + @modifies_config def del_user(conf: Config, name: str) -> Config: """Delete named user account.""" diff --git a/cista/droppy.py b/cista/droppy.py index 9e2ccf3..5636266 100755 --- a/cista/droppy.py +++ b/cista/droppy.py @@ -11,6 +11,7 @@ def readconf() -> dict: cf["listen"] = _droppy_listeners(cf) return cf | db + def _droppy_listeners(cf): """Convert Droppy listeners to our format, for typical cases but not in full.""" for listener in cf["listeners"]: @@ -20,15 +21,19 @@ def _droppy_listeners(cf): continue socket = listener.get("socket") if socket: - if isinstance(socket, list): socket = socket[0] + if isinstance(socket, list): + socket = socket[0] return f"{socket}" port = listener["port"] - if isinstance(port, list): port = port[0] + if isinstance(port, list): + port = port[0] host = listener["host"] - if isinstance(host, list): host = host[0] - if host in ("127.0.0.1", "::", "localhost"): return f":{port}" + if isinstance(host, list): + host = host[0] + if host in ("127.0.0.1", "::", "localhost"): + return f":{port}" return f"{host}:{port}" except (KeyError, IndexError): continue # If none matched, fallback to Droppy default - return f"0.0.0.0:8989" + return "0.0.0.0:8989" diff --git a/cista/fileio.py b/cista/fileio.py index 9c10c82..566aaea 100755 --- a/cista/fileio.py +++ b/cista/fileio.py @@ -1,7 +1,5 @@ import asyncio import os -import unicodedata -from pathlib import PurePosixPath from cista import config from cista.util import filename @@ -13,6 +11,7 @@ def fuid(stat) -> str: """Unique file ID. Stays the same on renames and modification.""" return config.derived_secret("filekey-inode", stat.st_dev, stat.st_ino).hex() + class File: def __init__(self, filename): self.path = config.config.path / filename @@ -30,23 +29,24 @@ class File: self.writable = True def write(self, pos, buffer, *, file_size=None): - assert self.fd is not None if not self.writable: # Create/open file self.open_rw() + assert self.fd is not None if file_size is not None: os.ftruncate(self.fd, file_size) os.lseek(self.fd, pos, os.SEEK_SET) os.write(self.fd, buffer) def __getitem__(self, slice): - assert self.fd is not None if self.fd is None: self.open_ro() + assert self.fd is not None os.lseek(self.fd, slice.start, os.SEEK_SET) - l = slice.stop - slice.start - data = os.read(self.fd, l) - if len(data) < l: raise EOFError("Error reading requested range") + size = slice.stop - slice.start + data = os.read(self.fd, size) + if len(data) < size: + raise EOFError("Error reading requested range") return data def close(self): @@ -59,10 +59,11 @@ class File: class FileServer: - async def start(self): self.alink = AsyncLink() - self.worker = asyncio.get_event_loop().run_in_executor(None, self.worker_thread, self.alink.to_sync) + self.worker = asyncio.get_event_loop().run_in_executor( + None, self.worker_thread, self.alink.to_sync + ) self.cache = LRUCache(File, capacity=10, maxage=5.0) async def stop(self): @@ -91,4 +92,4 @@ class FileServer: def download(self, name, start, end): name = filename.sanitize(name) f = self.cache[name] - return f[start: end] + return f[start:end] diff --git a/cista/protocol.py b/cista/protocol.py index 1964e06..31f6eff 100755 --- a/cista/protocol.py +++ b/cista/protocol.py @@ -11,19 +11,24 @@ from cista.util import filename ## Control commands + class ControlBase(msgspec.Struct, tag_field="op", tag=str.lower): def __call__(self): raise NotImplementedError + class MkDir(ControlBase): path: str + def __call__(self): path = config.config.path / filename.sanitize(self.path) path.mkdir(parents=False, exist_ok=False) + class Rename(ControlBase): path: str to: str + def __call__(self): to = filename.sanitize(self.to) if "/" in to: @@ -31,17 +36,21 @@ class Rename(ControlBase): path = config.config.path / filename.sanitize(self.path) path.rename(path.with_name(to)) + class Rm(ControlBase): sel: list[str] + def __call__(self): root = config.config.path sel = [root / filename.sanitize(p) for p in self.sel] for p in sel: shutil.rmtree(p, ignore_errors=True) + class Mv(ControlBase): sel: list[str] dst: str + def __call__(self): root = config.config.path sel = [root / filename.sanitize(p) for p in self.sel] @@ -51,9 +60,11 @@ class Mv(ControlBase): for p in sel: shutil.move(p, dst) + class Cp(ControlBase): sel: list[str] dst: str + def __call__(self): root = config.config.path sel = [root / filename.sanitize(p) for p in self.sel] @@ -62,29 +73,38 @@ class Cp(ControlBase): raise BadRequest("The destination must be a directory") for p in sel: # Note: copies as dst rather than in dst unless name is appended. - shutil.copytree(p, dst / p.name, dirs_exist_ok=True, ignore_dangling_symlinks=True) + shutil.copytree( + p, dst / p.name, dirs_exist_ok=True, ignore_dangling_symlinks=True + ) + ## File uploads and downloads + class FileRange(msgspec.Struct): name: str size: int start: int end: int + class StatusMsg(msgspec.Struct): status: str req: FileRange + class ErrorMsg(msgspec.Struct): error: dict[str, Any] + ## Directory listings + class FileEntry(msgspec.Struct): size: int mtime: int + class DirEntry(msgspec.Struct): size: int mtime: int @@ -104,23 +124,22 @@ class DirEntry(msgspec.Struct): @property def props(self): - return { - k: v - for k, v in self.__struct_fields__ - if k != "dir" - } + return {k: v for k, v in self.__struct_fields__ if k != "dir"} + DirList = dict[str, FileEntry | DirEntry] class UpdateEntry(msgspec.Struct, omit_defaults=True): """Updates the named entry in the tree. Fields that are set replace old values. A list of entries recurses directories.""" + name: str = "" deleted: bool = False size: int | None = None mtime: int | None = None dir: DirList | None = None + def make_dir_data(root): if len(root) == 2: return FileEntry(*root) diff --git a/cista/serve.py b/cista/serve.py index 4bf4af1..05c1232 100755 --- a/cista/serve.py +++ b/cista/serve.py @@ -10,6 +10,7 @@ from cista import config, server80 def run(dev=False): """Run Sanic main process that spawns worker processes to serve HTTP requests.""" from .app import app + url, opts = parse_listen(config.config.listen) # Silence Sanic's warning about running in production rather than debug os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1" @@ -21,20 +22,33 @@ def run(dev=False): domain = opts["host"] check_cert(confdir / domain, domain) opts["ssl"] = str(confdir / domain) # type: ignore - app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, reload_dir={confdir, wwwroot}, access_log=True) # type: ignore + app.prepare( + **opts, + motd=False, + dev=dev, + auto_reload=dev, + reload_dir={confdir, wwwroot}, + access_log=True, + ) # type: ignore Sanic.serve() + def check_cert(certdir, domain): if (certdir / "privkey.pem").exist() and (certdir / "fullchain.pem").exists(): return # TODO: Use certbot to fetch a cert - raise ValueError(f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}") + raise ValueError( + f"TLS certificate files privkey.pem and fullchain.pem needed in {certdir}" + ) + def parse_listen(listen): if listen.startswith("/"): unix = Path(listen).resolve() if not unix.parent.exists(): - raise ValueError(f"Directory for unix socket does not exist: {unix.parent}/") + raise ValueError( + f"Directory for unix socket does not exist: {unix.parent}/" + ) return "http://localhost", {"unix": unix} elif re.fullmatch(r"(\w+(-\w+)*\.)+\w{2,}", listen, re.UNICODE): return f"https://{listen}", {"host": listen, "port": 443, "ssl": True} diff --git a/cista/server80.py b/cista/server80.py index 9119a00..d4c7486 100644 --- a/cista/server80.py +++ b/cista/server80.py @@ -2,6 +2,7 @@ from sanic import Sanic, exceptions, response app = Sanic("server80") + # Send all HTTP users to HTTPS @app.exception(exceptions.NotFound, exceptions.MethodNotSupported) def redirect_everything_else(request, exception): @@ -10,6 +11,7 @@ def redirect_everything_else(request, exception): return response.redirect(f"https://{server}{path}", status=308) return response.text("Bad Request. Please use HTTPS!", status=400) + # ACME challenge for LetsEncrypt @app.get("/.well-known/acme-challenge/") async def letsencrypt(request, challenge): @@ -18,4 +20,5 @@ async def letsencrypt(request, challenge): except KeyError: return response.text(f"ACME challenge not found: {challenge}", status=404) + acme_challenges = {} diff --git a/cista/session.py b/cista/session.py index 3a97eee..6075bcb 100755 --- a/cista/session.py +++ b/cista/session.py @@ -4,16 +4,21 @@ import jwt from cista.config import derived_secret -session_secret = lambda: derived_secret("session") + +def session_secret(): + return derived_secret("session") + + max_age = 365 * 86400 # Seconds since last login + def get(request): try: return jwt.decode(request.cookies.s, session_secret(), algorithms=["HS256"]) - except Exception as e: - s = None + except Exception: return False if "s" in request.cookies else None + def create(res, username, **kwargs): data = { "exp": int(time()) + max_age, @@ -23,12 +28,14 @@ def create(res, username, **kwargs): s = jwt.encode(data, session_secret()) res.cookies.add_cookie("s", s, httponly=True, max_age=max_age) + def update(res, s, **kwargs): s.update(kwargs) s = jwt.encode(s, session_secret()) max_age = max(1, s["exp"] - int(time())) # type: ignore res.cookies.add_cookie("s", s, httponly=True, max_age=max_age) + def delete(res): res.cookies.delete_cookie("s") diff --git a/cista/util/apphelpers.py b/cista/util/apphelpers.py index f5d790d..7043354 100644 --- a/cista/util/apphelpers.py +++ b/cista/util/apphelpers.py @@ -14,10 +14,12 @@ def asend(ws, msg): """Send JSON message or bytes to a websocket""" return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode()) + def jres(data, **kwargs): """JSON Sanic response, using msgspec encoding""" return raw(msgspec.json.encode(data), content_type="application/json", **kwargs) + async def handle_sanic_exception(request, e): logger.exception(e) context, code = {}, 500 @@ -30,7 +32,9 @@ async def handle_sanic_exception(request, e): message = f"⚠️ {message}" if code < 500 else f"🛑 {message}" # Non-browsers get JSON errors if "text/html" not in request.headers.accept: - return jres(ErrorMsg({"code": code, "message": message, **context}), status=code) + return jres( + ErrorMsg({"code": code, "message": message, **context}), status=code + ) # Redirections flash the error message via cookies if "redirect" in context: res = redirect(context["redirect"]) @@ -39,8 +43,10 @@ async def handle_sanic_exception(request, e): # Otherwise use Sanic's default error page return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full() + def websocket_wrapper(handler): """Decorator for websocket handlers that catches exceptions and sends them back to the client""" + @wraps(handler) async def wrapper(request, ws, *args, **kwargs): try: @@ -55,4 +61,5 @@ def websocket_wrapper(handler): message = f"⚠️ {message}" if code < 500 else f"🛑 {message}" await asend(ws, ErrorMsg({"code": code, "message": message, **context})) raise + return wrapper diff --git a/cista/util/lrucache.py b/cista/util/lrucache.py index 524d757..528403f 100755 --- a/cista/util/lrucache.py +++ b/cista/util/lrucache.py @@ -11,6 +11,7 @@ class LRUCache: maxage (float): Max age for items in cache in seconds. cache (list): Internal list storing the cache items. """ + def __init__(self, open: callable, *, capacity: int, maxage: float): """ Initialize LRUCache. diff --git a/cista/util/pwgen.py b/cista/util/pwgen.py index 71e344a..e82ef9a 100644 --- a/cista/util/pwgen.py +++ b/cista/util/pwgen.py @@ -6,6 +6,7 @@ def generate(n=4): wl = list(words) return ".".join(wl.pop(secrets.randbelow(len(wl))) for i in range(n)) + # A custom list of 1024 common 3-6 letter words, with unique 3-prefixes and no prefix words, entropy 2.1b/letter 10b/word words: list = """ able about absent abuse access acid across act adapt add adjust admit adult advice affair afraid again age agree ahead diff --git a/cista/watching.py b/cista/watching.py index 7e85758..d383ddc 100755 --- a/cista/watching.py +++ b/cista/watching.py @@ -15,9 +15,18 @@ tree = {"": None} tree_lock = threading.Lock() rootpath: Path = None # type: ignore quit = False -modified_flags = "IN_CREATE", "IN_DELETE", "IN_DELETE_SELF", "IN_MODIFY", "IN_MOVE_SELF", "IN_MOVED_FROM", "IN_MOVED_TO" +modified_flags = ( + "IN_CREATE", + "IN_DELETE", + "IN_DELETE_SELF", + "IN_MODIFY", + "IN_MOVE_SELF", + "IN_MOVED_FROM", + "IN_MOVED_TO", +) disk_usage = None + def watcher_thread(loop): global disk_usage @@ -36,7 +45,8 @@ def watcher_thread(loop): refreshdl = time.monotonic() + 60.0 for event in i.event_gen(): - if quit: return + if quit: + return # Disk usage update du = shutil.disk_usage(rootpath) if du != disk_usage: @@ -44,8 +54,10 @@ def watcher_thread(loop): asyncio.run_coroutine_threadsafe(broadcast(format_du()), loop) break # Do a full refresh? - if time.monotonic() > refreshdl: break - if event is None: continue + if time.monotonic() > refreshdl: + break + if event is None: + continue _, flags, path, filename = event if not any(f in modified_flags for f in flags): continue @@ -58,19 +70,26 @@ def watcher_thread(loop): break i = None # Free the inotify object + def format_du(): - return msgspec.json.encode({"space": { - "disk": disk_usage.total, - "used": disk_usage.used, - "free": disk_usage.free, - "storage": tree[""].size, - }}).decode() + return msgspec.json.encode( + { + "space": { + "disk": disk_usage.total, + "used": disk_usage.used, + "free": disk_usage.free, + "storage": tree[""].size, + } + } + ).decode() + def format_tree(): root = tree[""] - return msgspec.json.encode({"update": [ - UpdateEntry(size=root.size, mtime=root.mtime, dir=root.dir) - ]}).decode() + return msgspec.json.encode( + {"update": [UpdateEntry(size=root.size, mtime=root.mtime, dir=root.dir)]} + ).decode() + def walk(path: Path) -> DirEntry | FileEntry | None: try: @@ -79,7 +98,12 @@ def walk(path: Path) -> DirEntry | FileEntry | None: if path.is_file(): return FileEntry(s.st_size, mtime) - tree = {p.name: v for p in path.iterdir() if not p.name.startswith('.') if (v := walk(p)) is not None} + tree = { + p.name: v + for p in path.iterdir() + if not p.name.startswith(".") + if (v := walk(p)) is not None + } if tree: size = sum(v.size for v in tree.values()) mtime = max(mtime, max(v.mtime for v in tree.values())) @@ -92,16 +116,21 @@ def walk(path: Path) -> DirEntry | FileEntry | None: print("OS error walking path", path, e) return None + def update(relpath: PurePosixPath, loop): """Called by inotify updates, check the filesystem and broadcast any changes.""" new = walk(rootpath / relpath) with tree_lock: update = update_internal(relpath, new) - if not update: return # No changes + if not update: + return # No changes msg = msgspec.json.encode({"update": update}).decode() asyncio.run_coroutine_threadsafe(broadcast(msg), loop) -def update_internal(relpath: PurePosixPath, new: DirEntry | FileEntry | None) -> list[UpdateEntry]: + +def update_internal( + relpath: PurePosixPath, new: DirEntry | FileEntry | None +) -> list[UpdateEntry]: path = "", *relpath.parts old = tree elems = [] @@ -142,25 +171,31 @@ def update_internal(relpath: PurePosixPath, new: DirEntry | FileEntry | None) -> u = UpdateEntry(name) if new: parent[name] = new - if u.size != new.size: u.size = new.size - if u.mtime != new.mtime: u.mtime = new.mtime + if u.size != new.size: + u.size = new.size + if u.mtime != new.mtime: + u.mtime = new.mtime if isinstance(new, DirEntry): - if u.dir == new.dir: u.dir = new.dir + if u.dir == new.dir: + u.dir = new.dir else: del parent[name] u.deleted = True update.append(u) return update + async def broadcast(msg): for queue in pubsub.values(): await queue.put_nowait(msg) + async def start(app, loop): config.load_config() app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop]) app.ctx.watcher.start() + async def stop(app, loop): global quit quit = True diff --git a/tests/test_control.py b/tests/test_control.py index 858f52b..06c3406 100644 --- a/tests/test_control.py +++ b/tests/test_control.py @@ -13,11 +13,13 @@ def setup_temp_dir(): config.config = config.Config(path=Path(tmpdirname), listen=":0") yield Path(tmpdirname) + def test_mkdir(setup_temp_dir): cmd = MkDir(path="new_folder") cmd() assert (setup_temp_dir / "new_folder").is_dir() + def test_rename(setup_temp_dir): (setup_temp_dir / "old_name").mkdir() cmd = Rename(path="old_name", to="new_name") @@ -25,12 +27,14 @@ def test_rename(setup_temp_dir): assert not (setup_temp_dir / "old_name").exists() assert (setup_temp_dir / "new_name").is_dir() + def test_rm(setup_temp_dir): (setup_temp_dir / "folder_to_remove").mkdir() cmd = Rm(sel=["folder_to_remove"]) cmd() assert not (setup_temp_dir / "folder_to_remove").exists() + def test_mv(setup_temp_dir): (setup_temp_dir / "folder_to_move").mkdir() (setup_temp_dir / "destination").mkdir() @@ -39,6 +43,7 @@ def test_mv(setup_temp_dir): assert not (setup_temp_dir / "folder_to_move").exists() assert (setup_temp_dir / "destination" / "folder_to_move").is_dir() + def test_cp(setup_temp_dir): (setup_temp_dir / "folder_to_copy").mkdir() (setup_temp_dir / "destination").mkdir() diff --git a/tests/test_lrucache.py b/tests/test_lrucache.py index 9903e13..20e4632 100644 --- a/tests/test_lrucache.py +++ b/tests/test_lrucache.py @@ -1,8 +1,6 @@ from time import sleep from unittest.mock import Mock -import pytest - from cista.util.lrucache import LRUCache # Replace with actual import @@ -12,24 +10,28 @@ def mock_open(key): mock.content = f"content-{key}" return mock + def test_contains(): cache = LRUCache(open=mock_open, capacity=2, maxage=10) assert "key1" not in cache cache["key1"] assert "key1" in cache + def test_getitem(): cache = LRUCache(open=mock_open, capacity=2, maxage=10) assert cache["key1"].content == "content-key1" + def test_capacity(): cache = LRUCache(open=mock_open, capacity=2, maxage=10) item1 = cache["key1"] - item2 = cache["key2"] + cache["key2"] cache["key3"] assert "key1" not in cache item1.close.assert_called() + def test_expiry(): cache = LRUCache(open=mock_open, capacity=2, maxage=0.1) item = cache["key1"] @@ -38,6 +40,7 @@ def test_expiry(): assert "key1" not in cache item.close.assert_called() + def test_close(): cache = LRUCache(open=mock_open, capacity=2, maxage=10) item = cache["key1"] @@ -45,6 +48,7 @@ def test_close(): assert "key1" not in cache item.close.assert_called() + def test_lru_mechanism(): cache = LRUCache(open=mock_open, capacity=2, maxage=10) item1 = cache["key1"]