Cleanup, bugfixes. Added access control on files and API.

This commit is contained in:
Leo Vasanko 2023-10-23 04:51:39 +03:00 committed by Leo Vasanko
parent bd61d7451e
commit 4852212347
14 changed files with 239 additions and 412 deletions

View File

@ -7,7 +7,7 @@ import cista
from cista import app, config, droppy, serve, server80
from cista.util import pwgen
app, server80.app # Needed for Sanic multiprocessing
del app, server80.app # Only import needed, for Sanic multiprocessing
doc = f"""Cista {cista.__version__} - A file storage for the web.

86
cista/api.py Normal file
View File

@ -0,0 +1,86 @@
import asyncio
import typing
import msgspec
from sanic import Blueprint, SanicException
from cista import watching
from cista.fileio import FileServer
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
from cista.util.apphelpers import asend, websocket_wrapper
bp = Blueprint("api", url_prefix="/api")
fileserver = FileServer()
@bp.before_server_start
async def start_fileserver(app, _):
await fileserver.start()
@bp.after_server_stop
async def stop_fileserver(app, _):
await fileserver.stop()
@bp.websocket('upload')
@websocket_wrapper
async def upload(req, ws):
alink = fileserver.alink
while True:
req = None
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
req = msgspec.json.decode(text, type=FileRange)
pos = req.start
data = None
while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
sentsize = await alink(("upload", req.name, pos, data, req.size))
pos += typing.cast(int, sentsize)
if pos != req.end:
d = f"{len(data)} bytes" if isinstance(data, bytes) else data
raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
# Report success
res = StatusMsg(status="ack", req=req)
await asend(ws, res)
#await ws.drain()
@bp.websocket('download')
@websocket_wrapper
async def download(req, ws):
alink = fileserver.alink
while True:
req = None
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
req = msgspec.json.decode(text, type=FileRange)
pos = req.start
while pos < req.end:
end = min(req.end, pos + (1<<20))
data = typing.cast(bytes, await alink(("download", req.name, pos, end)))
await asend(ws, data)
pos += len(data)
# Report success
res = StatusMsg(status="ack", req=req)
await asend(ws, res)
#await ws.drain()
@bp.websocket("control")
@websocket_wrapper
async def control(req, ws):
cmd = msgspec.json.decode(await ws.recv(), type=ControlBase)
await asyncio.to_thread(cmd)
await asend(ws, StatusMsg(status="ack", req=cmd))
@bp.websocket("watch")
@websocket_wrapper
async def watch(req, ws):
try:
with watching.tree_lock:
q = watching.pubsub[ws] = asyncio.Queue()
# Init with full tree
await ws.send(watching.refresh())
# Send updates
while True:
await ws.send(await q.get())
finally:
del watching.pubsub[ws]

View File

@ -1,72 +1,56 @@
import asyncio
from importlib.resources import files
import msgspec
from html5tagger import E
from sanic import Forbidden, Sanic, SanicException, errorpages, raw
from sanic.log import logger
from sanic.response import html, json, redirect
import mimetypes
from cista import config, session, watching
from cista.util import filename
from cista.auth import authbp
from cista.fileio import FileServer
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
from importlib.resources import files
from urllib.parse import unquote
app = Sanic("cista")
fileserver = FileServer()
watching.register(app, "/api/watch")
app.blueprint(authbp)
from html5tagger import E
from sanic import Blueprint, Sanic, raw
from sanic.exceptions import Forbidden, NotFound
@app.on_request
async def use_session(request):
request.ctx.session = session.get(request)
# CSRF protection
if request.method == "GET" and request.headers.upgrade != "websocket":
return # Ordinary GET requests are fine
if request.headers.origin.split("//", 1)[1] != request.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
from cista import auth, config, session, watching
from cista.api import bp
from cista.util import filename
from cista.util.apphelpers import handle_sanic_exception
@app.exception(Exception)
async def handle_sanic_exception(request, e):
context, code, message = {}, 500, str(e) or "Internal Server Error"
if isinstance(e, SanicException):
context = e.context or {}
code = e.status_code
message = f"⚠️ {message}"
# Non-browsers get JSON errors
if "text/html" not in request.headers.accept:
return json({"error": {"code": code, "message": message, **context}}, status=code)
# Redirections flash the error message via cookies
if "redirect" in context:
res = redirect(context["redirect"])
res.cookies.add_cookie("message", message, max_age=5)
return res
# Otherwise use Sanic's default error page
return errorpages.HTMLRenderer(request, e, debug=app.debug).full()
@app.on_response
async def resphandler(request, response):
print("In response handler", request.path, response)
def asend(ws, msg):
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
app = Sanic("cista", strict_slashes=True)
app.blueprint(auth.bp)
app.blueprint(bp)
app.exception(Exception)(handle_sanic_exception)
@app.before_server_start
async def start_fileserver(app, _):
config.load_config() # Main process may have loaded it but we haven't
app.static("/files", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
await fileserver.start()
async def main_start(app, loop):
config.load_config()
await watching.start(app, loop)
@app.after_server_stop
async def stop_fileserver(app, _):
await fileserver.stop()
async def main_stop(app, loop):
await watching.stop(app, loop)
@app.get("/<path:path>")
async def wwwroot(request, path=""):
@app.on_request
async def use_session(req):
req.ctx.session = session.get(req)
try:
req.ctx.user = config.config.users[req.ctx.session["user"]] # type: ignore
except (AttributeError, KeyError, TypeError):
req.ctx.user = None
# CSRF protection
if req.method == "GET" and req.headers.upgrade != "websocket":
return # Ordinary GET requests are fine
# Check that origin matches host, for browsers which should all send Origin.
# Curl doesn't send any Origin header, so we allow it anyway.
origin = req.headers.origin
if origin and origin.split("//", 1)[1] != req.host:
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
@app.before_server_start
def http_fileserver(app, _):
bp = Blueprint("fileserver")
bp.on_request(auth.verify)
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
app.blueprint(bp)
@app.get("/<path:path>", static=True)
async def wwwroot(req, path=""):
"""Frontend files only"""
name = filename.sanitize(unquote(path)) if path else "index.html"
try:
index = files("cista").joinpath("wwwroot", name).read_bytes()
@ -74,67 +58,3 @@ async def wwwroot(request, path=""):
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
return raw(index, content_type=mime)
@app.websocket('/api/upload')
async def upload(request, ws):
alink = fileserver.alink
url = request.url_for("upload")
while True:
req = None
try:
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
req = msgspec.json.decode(text, type=FileRange)
pos = req.start
while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
pos += await alink(("upload", req.name, pos, data, req.size))
if pos != req.end:
d = f"{len(data)} bytes" if isinstance(data, bytes) else data
raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
# Report success
res = StatusMsg(status="ack", req=req)
await asend(ws, res)
await ws.drain()
except Exception as e:
res = ErrorMsg(error=str(e), req=req)
await asend(ws, res)
logger.exception(repr(res), e)
return
@app.websocket('/api/download')
async def download(request, ws):
alink = fileserver.alink
while True:
req = None
try:
print("Waiting for download command")
text = await ws.recv()
if not isinstance(text, str):
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
req = msgspec.json.decode(text, type=FileRange)
print("download", req)
pos = req.start
while pos < req.end:
end = min(req.end, pos + (1<<20))
data = await alink(("download", req.name, pos, end))
await asend(ws, data)
pos += len(data)
# Report success
res = StatusMsg(status="ack", req=req)
await asend(ws, res)
print(ws, dir(ws))
await ws.drain()
print(res)
except Exception as e:
res = ErrorMsg(error=str(e), req=req)
await asend(ws, res)
logger.exception(repr(res), e)
return
@app.websocket("/api/control")
async def control(request, websocket):
cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
await asyncio.to_thread(cmd)
await asend(websocket, StatusMsg(status="ack", req=cmd))

View File

@ -6,7 +6,8 @@ from unicodedata import normalize
import argon2
import msgspec
from html5tagger import Document
from sanic import BadRequest, Blueprint, Forbidden, html, json, redirect
from sanic import Blueprint, html, json, redirect
from sanic.exceptions import BadRequest, Forbidden, Unauthorized
from cista import config, session
@ -56,9 +57,18 @@ class LoginResponse(msgspec.Struct):
privileged: bool = False
error: str = ""
authbp = Blueprint("auth")
def verify(request, privileged=False):
"""Raise Unauthorized or Forbidden if the request is not authorized"""
if privileged:
if request.ctx.user:
if request.ctx.user.privileged: return
raise Forbidden("Access Forbidden: Only for privileged users")
elif config.config.public or request.ctx.user: return
raise Unauthorized("Login required", "cookie", context={"redirect": "/login"})
@authbp.get("/login")
bp = Blueprint("auth")
@bp.get("/login")
async def login_page(request):
doc = Document("Cista Login")
with doc.div(id="login"):
@ -82,7 +92,7 @@ async def login_page(request):
session.delete(res)
return res
@authbp.post("/login")
@bp.post("/login")
async def login_post(request):
try:
if request.headers.content_type == "application/json":
@ -108,7 +118,7 @@ async def login_post(request):
session.create(res, username)
return res
@authbp.post("/logout")
@bp.post("/logout")
async def logout_post(request):
s = request.ctx.session
msg = "Logged out" if s else "Not logged in"

View File

@ -30,6 +30,7 @@ class File:
self.writable = True
def write(self, pos, buffer, *, file_size=None):
assert self.fd is not None
if not self.writable:
# Create/open file
self.open_rw()
@ -39,6 +40,7 @@ class File:
os.write(self.fd, buffer)
def __getitem__(self, slice):
assert self.fd is not None
if self.fd is None:
self.open_ro()
os.lseek(self.fd, slice.start, os.SEEK_SET)
@ -71,8 +73,6 @@ class FileServer:
try:
for req in slink:
with req as (command, *args):
if command == "control":
self.control(*args)
if command == "upload":
req.set_result(self.upload(*args))
elif command == "download":

View File

@ -1,241 +0,0 @@
<!DOCTYPE html>
<title>Storage</title>
<style>
body {
font-family: sans-serif;
max-width: 100ch;
margin: 0 auto;
padding: 1em;
background-color: #333;
color: #eee;
}
td {
text-align: right;
padding: .5em;
}
td:first-child {
text-align: left;
}
a {
color: inherit;
text-decoration: none;
}
</style>
<div>
<h2>Quick file upload</h2>
<p>Uses parallel WebSocket connections for increased bandwidth /api/upload</p>
<input type=file id=fileInput>
<progress id=progressBar value=0 max=1></progress>
</div>
<div>
<h2>Files</h2>
<ul id=file_list></ul>
</div>
<script>
let files = {}
let flatfiles = {}
function createWatchSocket() {
const wsurl = new URL("/api/watch", location.href.replace(/^http/, 'ws'))
const ws = new WebSocket(wsurl)
ws.onmessage = event => {
msg = JSON.parse(event.data)
if (msg.update) {
tree_update(msg.update)
file_list(files)
} else {
console.log("Unkonwn message from watch socket", msg)
}
}
}
createWatchSocket()
function tree_update(msg) {
console.log("Tree update", msg)
let node = files
for (const elem of msg) {
if (elem.deleted) {
const p = node.dir[elem.name].path
delete node.dir[elem.name]
delete flatfiles[p]
break
}
if (elem.name !== undefined) node = node.dir[elem.name] ||= {}
if (elem.size !== undefined) node.size = elem.size
if (elem.mtime !== undefined) node.mtime = elem.mtime
if (elem.dir !== undefined) node.dir = elem.dir
}
// Update paths and flatfiles
files.path = "/"
const nodes = [files]
flatfiles = {}
while (node = nodes.pop()) {
flatfiles[node.path] = node
if (node.dir === undefined) continue
for (const name of Object.keys(node.dir)) {
const child = node.dir[name]
child.path = node.path + name + (child.dir === undefined ? "" : "/")
nodes.push(child)
}
}
}
var collator = new Intl.Collator(undefined, {numeric: true, sensitivity: 'base'});
const compare_path = (a, b) => collator.compare(a.path, b.path)
const compare_time = (a, b) => a.mtime > b.mtime
function file_list(files) {
const table = document.getElementById("file_list")
const sorted = Object.values(flatfiles).sort(compare_time)
table.innerHTML = ""
for (const f of sorted) {
const {path, size, mtime} = f
const tr = document.createElement("tr")
const name_td = document.createElement("td")
const size_td = document.createElement("td")
const mtime_td = document.createElement("td")
const a = document.createElement("a")
table.appendChild(tr)
tr.appendChild(name_td)
tr.appendChild(size_td)
tr.appendChild(mtime_td)
name_td.appendChild(a)
size_td.textContent = size
mtime_td.textContent = formatUnixDate(mtime)
a.textContent = path
a.href = `/files${path}`
/*a.onclick = event => {
if (window.showSaveFilePicker) {
event.preventDefault()
download_ws(name, size)
}
}
a.download = ""*/
}
}
function formatUnixDate(t) {
const date = new Date(t * 1000)
const now = new Date()
const diff = date - now
const formatter = new Intl.RelativeTimeFormat('en', { numeric: 'auto' })
if (Math.abs(diff) <= 60000) {
return formatter.format(Math.round(diff / 1000), 'second')
}
if (Math.abs(diff) <= 3600000) {
return formatter.format(Math.round(diff / 60000), 'minute')
}
if (Math.abs(diff) <= 86400000) {
return formatter.format(Math.round(diff / 3600000), 'hour')
}
if (Math.abs(diff) <= 604800000) {
return formatter.format(Math.round(diff / 86400000), 'day')
}
return date.toLocaleDateString()
}
async function download_ws(name, size) {
const fh = await window.showSaveFilePicker({
suggestedName: name,
})
const writer = await fh.createWritable()
writer.truncate(size)
const wsurl = new URL("/api/download", location.href.replace(/^http/, 'ws'))
const ws = new WebSocket(wsurl)
let pos = 0
ws.onopen = () => {
console.log("Downloading over WebSocket", name, size)
ws.send(JSON.stringify({name, start: 0, end: size, size}))
}
ws.onmessage = event => {
if (typeof event.data === 'string') {
const msg = JSON.parse(event.data)
console.log("Download finished", msg)
ws.close()
return
}
console.log("Received chunk", name, pos, pos + event.data.size)
pos += event.data.size
writer.write(event.data)
}
ws.onclose = () => {
if (pos < size) {
console.log("Download aborted", name, pos)
writer.truncate(pos)
}
writer.close()
}
}
const fileInput = document.getElementById("fileInput")
const progress = document.getElementById("progressBar")
const numConnections = 2
const chunkSize = 1<<20
const wsConnections = new Set()
//for (let i = 0; i < numConnections; i++) createUploadWS()
function createUploadWS() {
const wsurl = new URL("/api/upload", location.href.replace(/^http/, 'ws'))
const ws = new WebSocket(wsurl)
ws.binaryType = 'arraybuffer'
ws.onopen = () => {
wsConnections.add(ws)
console.log("Upload socket connected")
}
ws.onmessage = event => {
msg = JSON.parse(event.data)
if (msg.written) progress.value += +msg.written
else console.log(`Error: ${msg.error}`)
}
ws.onclose = () => {
wsConnections.delete(ws)
console.log("Upload socket disconnected, reconnecting...")
setTimeout(createUploadWS, 1000)
}
}
async function load(file, start, end) {
const reader = new FileReader()
const load = new Promise(resolve => reader.onload = resolve)
reader.readAsArrayBuffer(file.slice(start, end))
const event = await load
return event.target.result
}
async function sendChunk(file, start, end, ws) {
const chunk = await load(file, start, end)
ws.send(JSON.stringify({
name: file.name,
size: file.size,
start: start,
end: end
}))
ws.send(chunk)
}
fileInput.addEventListener("change", async function() {
const file = this.files[0]
const numChunks = Math.ceil(file.size / chunkSize)
progress.value = 0
progress.max = file.size
console.log(wsConnections)
for (let i = 0; i < numChunks; i++) {
const ws = Array.from(wsConnections)[i % wsConnections.size]
const start = i * chunkSize
const end = Math.min(file.size, start + chunkSize)
const res = await sendChunk(file, start, end, ws)
}
})
</script>

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import shutil
from typing import Any
import msgspec
from sanic import BadRequest
@ -71,14 +72,12 @@ class FileRange(msgspec.Struct):
start: int
end: int
class ErrorMsg(msgspec.Struct):
error: str
req: FileRange
class StatusMsg(msgspec.Struct):
status: str
req: FileRange
class ErrorMsg(msgspec.Struct):
error: dict[str, Any]
## Directory listings

View File

@ -4,7 +4,7 @@ from pathlib import Path
from sanic import Sanic
from cista import config
from cista import config, server80
def run(dev=False):
@ -15,11 +15,10 @@ def run(dev=False):
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1"
if opts.get("ssl"):
# Run plain HTTP redirect/acme server on port 80
from . import server80
server80.app.prepare(port=80, motd=False)
domain = opts["host"]
opts["ssl"] = str(config.conffile.parent / domain)
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, access_log=True)
opts["ssl"] = str(config.conffile.parent / domain) # type: ignore
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, access_log=True) # type: ignore
Sanic.serve()
def parse_listen(listen):

View File

@ -26,7 +26,8 @@ def create(res, username, **kwargs):
def update(res, s, **kwargs):
s.update(kwargs)
s = jwt.encode(s, session_secret())
res.cookies.add_cookie("s", s, httponly=True, max_age=max(1, s["exp"] - int(time())))
max_age = max(1, s["exp"] - int(time())) # type: ignore
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
def delete(res):
res.cookies.delete_cookie("s")

58
cista/util/apphelpers.py Normal file
View File

@ -0,0 +1,58 @@
from functools import wraps
import msgspec
from sanic import errorpages
from sanic.exceptions import SanicException
from sanic.log import logger
from sanic.response import raw, redirect
from cista import auth
from cista.protocol import ErrorMsg
def asend(ws, msg):
"""Send JSON message or bytes to a websocket"""
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
def jres(data, **kwargs):
"""JSON Sanic response, using msgspec encoding"""
return raw(msgspec.json.encode(data), content_type="application/json", **kwargs)
async def handle_sanic_exception(request, e):
logger.exception(e)
context, code = {}, 500
message = str(e)
if isinstance(e, SanicException):
context = e.context or {}
code = e.status_code
if not message or not request.app.debug and code == 500:
message = "Internal Server Error"
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
# Non-browsers get JSON errors
if "text/html" not in request.headers.accept:
return jres(ErrorMsg({"code": code, "message": message, **context}), status=code)
# Redirections flash the error message via cookies
if "redirect" in context:
res = redirect(context["redirect"])
res.cookies.add_cookie("message", message, max_age=5)
return res
# Otherwise use Sanic's default error page
return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full()
def websocket_wrapper(handler):
"""Decorator for websocket handlers that catches exceptions and sends them back to the client"""
@wraps(handler)
async def wrapper(request, ws, *args, **kwargs):
try:
auth.verify(request)
await handler(request, ws, *args, **kwargs)
except Exception as e:
logger.exception(e)
context, code, message = {}, 500, str(e) or "Internal Server Error"
if isinstance(e, SanicException):
context = e.context or {}
code = e.status_code
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
await asend(ws, ErrorMsg({"code": code, "message": message, **context}))
raise
return wrapper

View File

@ -1,7 +1,9 @@
from pathvalidate import sanitize_filepath
import unicodedata
from pathlib import PurePosixPath
from pathvalidate import sanitize_filepath
def sanitize(filename: str) -> str:
filename = unicodedata.normalize("NFC", filename)
# UNIX filenames can contain backslashes but for compatibility we replace them with dashes

View File

@ -8,6 +8,7 @@ import msgspec
from cista import config
from cista.protocol import DirEntry, FileEntry, UpdateEntry
from cista.util.apphelpers import websocket_wrapper
pubsub = {}
tree = {"": None}
@ -132,30 +133,14 @@ async def broadcast(msg):
for queue in pubsub.values():
await queue.put_nowait(msg)
def register(app, url):
@app.before_server_start
async def start_watcher(app, loop):
global rootpath
config.load_config()
rootpath = config.config.path
app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop])
app.ctx.watcher.start()
async def start(app, loop):
global rootpath
config.load_config()
rootpath = config.config.path
app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop])
app.ctx.watcher.start()
@app.after_server_stop
async def stop_watcher(app, _):
global quit
quit = True
app.ctx.watcher.join()
@app.websocket(url)
async def watch(request, ws):
try:
with tree_lock:
q = pubsub[ws] = asyncio.Queue()
# Init with full tree
await ws.send(refresh())
# Send updates
while True:
await ws.send(await q.get())
finally:
del pubsub[ws]
async def stop(app, loop):
global quit
quit = True
app.ctx.watcher.join()

View File

@ -62,3 +62,8 @@ addopts = [
testpaths = [
"tests",
]
[tool.isort]
#src_paths = ["cista", "tests"]
line_length = 120
multi_line_output = 5

View File

@ -1,8 +1,11 @@
import pytest
from cista import config
from cista.protocol import MkDir, Rename, Rm, Mv, Cp
from pathlib import Path
import tempfile
from pathlib import Path
import pytest
from cista import config
from cista.protocol import Cp, MkDir, Mv, Rename, Rm
@pytest.fixture
def setup_temp_dir():