cista-storage/cista/app.py

import asyncio
import datetime
import mimetypes
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from importlib.resources import files
from pathlib import Path, PurePath
from stat import S_IFDIR, S_IFREG
from urllib.parse import unquote
from wsgiref.handlers import format_date_time
import brotli
import sanic.helpers
from blake3 import blake3
from natsort import natsorted, ns
from sanic import Blueprint, Sanic, empty, raw
from sanic.exceptions import Forbidden, NotFound, ServerError, ServiceUnavailable
from sanic.log import logging
from stream_zip import ZIP_AUTO, stream_zip
from cista import auth, config, session, watching
from cista.api import bp
from cista.protocol import DirEntry, DirList
from cista.util.apphelpers import handle_sanic_exception

# Workaround until Sanic PR #2824 is merged: emptying this set keeps Sanic
# from stripping entity headers such as ETag off 304 Not Modified responses.
sanic.helpers._ENTITY_HEADERS = frozenset()

app = Sanic("cista", strict_slashes=True)
app.blueprint(auth.bp)
app.blueprint(bp)
app.exception(Exception)(handle_sanic_exception)
@app.before_server_start
async def main_start(app, loop):
config.load_config()
await watching.start(app, loop)
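    # Thread pool for blocking work (wwwroot loading, ZIP streaming)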
app.ctx.threadexec = ThreadPoolExecutor(
max_workers=8, thread_name_prefix="cista-ioworker"
)

@app.after_server_stop
async def main_stop(app, loop):
await watching.stop(app, loop)
app.ctx.threadexec.shutdown()
@app.on_request
async def use_session(req):
req.ctx.session = session.get(req)
try:
req.ctx.username = req.ctx.session["username"] # type: ignore
req.ctx.user = config.config.users[req.ctx.username]
except (AttributeError, KeyError, TypeError):
req.ctx.username = None
req.ctx.user = None
# CSRF protection
if req.method == "GET" and req.headers.upgrade != "websocket":
return # Ordinary GET requests are fine
    # Check that Origin matches Host; browsers should always send Origin here.
    # Curl doesn't send any Origin header, so we allow it anyway.
origin = req.headers.origin
if origin and origin.split("//", 1)[1] != req.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
@app.before_server_start
def http_fileserver(app, _):
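    # Registered at server start rather than import time so that the shared
    # root from config.load_config() is available as config.config.path.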
bp = Blueprint("fileserver")
bp.on_request(auth.verify)
bp.static(
"/files/",
config.config.path,
use_content_range=True,
stream_large_files=True,
directory_view=True,
)
app.blueprint(bp)
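
# In-memory frontend files: name -> (data, brotli-compressed bytes or False, headers)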
www = {}

def _load_wwwroot(www):
wwwnew = {}
base = Path(__file__).with_name("wwwroot")
paths = [PurePath()]
while paths:
path = paths.pop(0)
current = base / path
for p in current.iterdir():
if p.is_dir():
paths.append(p.relative_to(base))
continue
name = p.relative_to(base).as_posix()
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
mtime = p.stat().st_mtime
data = p.read_bytes()
etag = blake3(data).hexdigest(length=8)
if name == "index.html":
name = ""
# Use old data if not changed
if name in www and www[name][2]["etag"] == etag:
wwwnew[name] = www[name]
continue
# Add charset definition
if mime.startswith("text/"):
mime = f"{mime}; charset=UTF-8"
            # Asset file names change whenever their content changes
cached = name.startswith("assets/")
headers = {
"etag": etag,
"last-modified": format_date_time(mtime),
"cache-control": "max-age=31536000, immutable"
if cached
else "no-cache",
"content-type": mime,
}
# Precompress with Brotli
br = brotli.compress(data)
if len(br) >= len(data):
br = False
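            # A falsy br means compression didn't help; serve uncompressed only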
wwwnew[name] = data, br, headers
if not wwwnew:
raise ServerError(
"Web frontend missing. Did you forget npm run build?",
extra={"wwwroot": str(base)},
quiet=True,
)
return wwwnew
@app.before_server_start
async def start(app):
await load_wwwroot(app)
if app.debug:
app.add_task(refresh_wwwroot())

async def load_wwwroot(app):
global www
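    # Run the blocking scan in the thread pool; passing the old dict lets
    # unchanged files (same etag) be reused instead of recompressed.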
www = await asyncio.get_event_loop().run_in_executor(
app.ctx.threadexec, _load_wwwroot, www
)

async def refresh_wwwroot():
while True:
await asyncio.sleep(0.5)
try:
wwwold = www
await load_wwwroot(app)
changes = ""
for name in sorted(www):
attr = www[name]
if wwwold.get(name) == attr:
continue
headers = attr[2]
changes += f"{headers['last-modified']} {headers['etag']} /{name}\n"
for name in sorted(set(wwwold) - set(www)):
changes += f"Deleted /{name}\n"
if changes:
print(f"Updated wwwroot:\n{changes}", end="", flush=True)
except Exception as e:
print("Error loading wwwroot", e)
if not app.debug:
return

@app.route("/<path:path>", methods=["GET", "HEAD"])
async def wwwroot(req, path=""):
"""Frontend files only"""
name = unquote(path)
if name not in www:
raise NotFound(f"File not found: /{path}", extra={"name": name})
data, br, headers = www[name]
if req.headers.if_none_match == headers["etag"]:
# The client has it cached, respond 304 Not Modified
return empty(304, headers=headers)
# Brotli compressed?
if br and "br" in req.headers.accept_encoding.split(", "):
headers = {**headers, "content-encoding": "br"}
data = br
return raw(data, headers=headers)
def get_files(wanted: set) -> list:
if not isinstance(watching.tree[""], DirEntry):
raise ServiceUnavailable(headers={"retry-after": 1})
root = Path(watching.rootpath)
with watching.tree_lock:
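        # Depth-first walk of the cached tree: rel is the path relative to the
        # first wanted ancestor and stays None outside any wanted subtree.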
q: deque[tuple[list[str], None | list[str], DirList]] = deque(
[([], None, watching.tree[""].dir)]
)
        files = []
        while q:
locpar, relpar, d = q.pop()
for name, attr in d.items():
loc = [*locpar, name]
rel = None
if relpar or attr.key in wanted:
rel = [*relpar, name] if relpar else [name]
wanted.discard(attr.key)
isdir = isinstance(attr, DirEntry)
if isdir:
q.append((loc, rel, attr.dir))
if rel:
files.append(("/".join(rel), root.joinpath(*loc)))
    return natsorted(files, key=lambda f: f[0], alg=ns.IGNORECASE)

@app.get("/zip/<keys>/<zipfile:ext=zip>")
async def zip_download(req, keys, zipfile, ext):
"""Download a zip archive of the given keys"""
wanted = set(keys.split("+"))
files = get_files(wanted)
if not files:
raise NotFound(
"No files found",
context={"keys": keys, "zipfile": zipfile, "wanted": wanted},
)
if wanted:
raise NotFound("Files not found", context={"missing": wanted})
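
    # stream_zip consumes (name, mtime, mode, method, chunk-iterable) tuples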
def local_files(files):
for rel, p in files:
s = p.stat()
size = s.st_size
modified = datetime.datetime.fromtimestamp(s.st_mtime, datetime.UTC)
if p.is_dir():
yield rel, modified, S_IFDIR | 0o755, ZIP_AUTO(size), b""
else:
yield rel, modified, S_IFREG | 0o644, ZIP_AUTO(size), contents(p)

    def contents(name):
with name.open("rb") as f:
while chunk := f.read(65536):
yield chunk

    def worker():
try:
for chunk in stream_zip(local_files(files)):
                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result()
except Exception:
logging.exception("Error streaming ZIP")
raise
finally:
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
# Don't block the event loop: run in a thread
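    # maxsize=1 applies backpressure: the worker thread stays at most one chunk
    # ahead of what the client has downloaded.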
queue = asyncio.Queue(maxsize=1)
loop = asyncio.get_event_loop()
thread = loop.run_in_executor(app.ctx.threadexec, worker)
# Stream the response
res = await req.respond(content_type="application/zip")
while chunk := await queue.get():
await res.send(chunk)
    await thread  # Re-raise any exception from the worker so the download fails visibly