Cleanup, bugfixes. Added access control on files and API.
This commit is contained in:
parent
bd61d7451e
commit
4852212347
|
@ -7,7 +7,7 @@ import cista
|
||||||
from cista import app, config, droppy, serve, server80
|
from cista import app, config, droppy, serve, server80
|
||||||
from cista.util import pwgen
|
from cista.util import pwgen
|
||||||
|
|
||||||
app, server80.app # Needed for Sanic multiprocessing
|
del app, server80.app # Only import needed, for Sanic multiprocessing
|
||||||
|
|
||||||
doc = f"""Cista {cista.__version__} - A file storage for the web.
|
doc = f"""Cista {cista.__version__} - A file storage for the web.
|
||||||
|
|
||||||
|
|
86
cista/api.py
Normal file
86
cista/api.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import asyncio
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
from sanic import Blueprint, SanicException
|
||||||
|
|
||||||
|
from cista import watching
|
||||||
|
from cista.fileio import FileServer
|
||||||
|
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
|
||||||
|
from cista.util.apphelpers import asend, websocket_wrapper
|
||||||
|
|
||||||
|
bp = Blueprint("api", url_prefix="/api")
|
||||||
|
fileserver = FileServer()
|
||||||
|
|
||||||
|
@bp.before_server_start
|
||||||
|
async def start_fileserver(app, _):
|
||||||
|
await fileserver.start()
|
||||||
|
|
||||||
|
@bp.after_server_stop
|
||||||
|
async def stop_fileserver(app, _):
|
||||||
|
await fileserver.stop()
|
||||||
|
|
||||||
|
@bp.websocket('upload')
|
||||||
|
@websocket_wrapper
|
||||||
|
async def upload(req, ws):
|
||||||
|
alink = fileserver.alink
|
||||||
|
while True:
|
||||||
|
req = None
|
||||||
|
text = await ws.recv()
|
||||||
|
if not isinstance(text, str):
|
||||||
|
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
||||||
|
req = msgspec.json.decode(text, type=FileRange)
|
||||||
|
pos = req.start
|
||||||
|
data = None
|
||||||
|
while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
|
||||||
|
sentsize = await alink(("upload", req.name, pos, data, req.size))
|
||||||
|
pos += typing.cast(int, sentsize)
|
||||||
|
if pos != req.end:
|
||||||
|
d = f"{len(data)} bytes" if isinstance(data, bytes) else data
|
||||||
|
raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
|
||||||
|
# Report success
|
||||||
|
res = StatusMsg(status="ack", req=req)
|
||||||
|
await asend(ws, res)
|
||||||
|
#await ws.drain()
|
||||||
|
|
||||||
|
@bp.websocket('download')
|
||||||
|
@websocket_wrapper
|
||||||
|
async def download(req, ws):
|
||||||
|
alink = fileserver.alink
|
||||||
|
while True:
|
||||||
|
req = None
|
||||||
|
text = await ws.recv()
|
||||||
|
if not isinstance(text, str):
|
||||||
|
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
||||||
|
req = msgspec.json.decode(text, type=FileRange)
|
||||||
|
pos = req.start
|
||||||
|
while pos < req.end:
|
||||||
|
end = min(req.end, pos + (1<<20))
|
||||||
|
data = typing.cast(bytes, await alink(("download", req.name, pos, end)))
|
||||||
|
await asend(ws, data)
|
||||||
|
pos += len(data)
|
||||||
|
# Report success
|
||||||
|
res = StatusMsg(status="ack", req=req)
|
||||||
|
await asend(ws, res)
|
||||||
|
#await ws.drain()
|
||||||
|
|
||||||
|
@bp.websocket("control")
|
||||||
|
@websocket_wrapper
|
||||||
|
async def control(req, ws):
|
||||||
|
cmd = msgspec.json.decode(await ws.recv(), type=ControlBase)
|
||||||
|
await asyncio.to_thread(cmd)
|
||||||
|
await asend(ws, StatusMsg(status="ack", req=cmd))
|
||||||
|
|
||||||
|
@bp.websocket("watch")
|
||||||
|
@websocket_wrapper
|
||||||
|
async def watch(req, ws):
|
||||||
|
try:
|
||||||
|
with watching.tree_lock:
|
||||||
|
q = watching.pubsub[ws] = asyncio.Queue()
|
||||||
|
# Init with full tree
|
||||||
|
await ws.send(watching.refresh())
|
||||||
|
# Send updates
|
||||||
|
while True:
|
||||||
|
await ws.send(await q.get())
|
||||||
|
finally:
|
||||||
|
del watching.pubsub[ws]
|
166
cista/app.py
166
cista/app.py
|
@ -1,72 +1,56 @@
|
||||||
import asyncio
|
|
||||||
from importlib.resources import files
|
|
||||||
|
|
||||||
import msgspec
|
|
||||||
from html5tagger import E
|
|
||||||
from sanic import Forbidden, Sanic, SanicException, errorpages, raw
|
|
||||||
from sanic.log import logger
|
|
||||||
from sanic.response import html, json, redirect
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
from importlib.resources import files
|
||||||
from cista import config, session, watching
|
|
||||||
from cista.util import filename
|
|
||||||
from cista.auth import authbp
|
|
||||||
from cista.fileio import FileServer
|
|
||||||
from cista.protocol import ControlBase, ErrorMsg, FileRange, StatusMsg
|
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
app = Sanic("cista")
|
from html5tagger import E
|
||||||
fileserver = FileServer()
|
from sanic import Blueprint, Sanic, raw
|
||||||
watching.register(app, "/api/watch")
|
from sanic.exceptions import Forbidden, NotFound
|
||||||
app.blueprint(authbp)
|
|
||||||
|
|
||||||
@app.on_request
|
from cista import auth, config, session, watching
|
||||||
async def use_session(request):
|
from cista.api import bp
|
||||||
request.ctx.session = session.get(request)
|
from cista.util import filename
|
||||||
# CSRF protection
|
from cista.util.apphelpers import handle_sanic_exception
|
||||||
if request.method == "GET" and request.headers.upgrade != "websocket":
|
|
||||||
return # Ordinary GET requests are fine
|
|
||||||
if request.headers.origin.split("//", 1)[1] != request.host:
|
|
||||||
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
|
|
||||||
|
|
||||||
@app.exception(Exception)
|
app = Sanic("cista", strict_slashes=True)
|
||||||
async def handle_sanic_exception(request, e):
|
app.blueprint(auth.bp)
|
||||||
context, code, message = {}, 500, str(e) or "Internal Server Error"
|
app.blueprint(bp)
|
||||||
if isinstance(e, SanicException):
|
app.exception(Exception)(handle_sanic_exception)
|
||||||
context = e.context or {}
|
|
||||||
code = e.status_code
|
|
||||||
message = f"⚠️ {message}"
|
|
||||||
# Non-browsers get JSON errors
|
|
||||||
if "text/html" not in request.headers.accept:
|
|
||||||
return json({"error": {"code": code, "message": message, **context}}, status=code)
|
|
||||||
# Redirections flash the error message via cookies
|
|
||||||
if "redirect" in context:
|
|
||||||
res = redirect(context["redirect"])
|
|
||||||
res.cookies.add_cookie("message", message, max_age=5)
|
|
||||||
return res
|
|
||||||
# Otherwise use Sanic's default error page
|
|
||||||
return errorpages.HTMLRenderer(request, e, debug=app.debug).full()
|
|
||||||
|
|
||||||
@app.on_response
|
|
||||||
async def resphandler(request, response):
|
|
||||||
print("In response handler", request.path, response)
|
|
||||||
|
|
||||||
def asend(ws, msg):
|
|
||||||
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
|
|
||||||
|
|
||||||
@app.before_server_start
|
@app.before_server_start
|
||||||
async def start_fileserver(app, _):
|
async def main_start(app, loop):
|
||||||
config.load_config() # Main process may have loaded it but we haven't
|
config.load_config()
|
||||||
app.static("/files", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
|
await watching.start(app, loop)
|
||||||
|
|
||||||
await fileserver.start()
|
|
||||||
|
|
||||||
@app.after_server_stop
|
@app.after_server_stop
|
||||||
async def stop_fileserver(app, _):
|
async def main_stop(app, loop):
|
||||||
await fileserver.stop()
|
await watching.stop(app, loop)
|
||||||
|
|
||||||
@app.get("/<path:path>")
|
@app.on_request
|
||||||
async def wwwroot(request, path=""):
|
async def use_session(req):
|
||||||
|
req.ctx.session = session.get(req)
|
||||||
|
try:
|
||||||
|
req.ctx.user = config.config.users[req.ctx.session["user"]] # type: ignore
|
||||||
|
except (AttributeError, KeyError, TypeError):
|
||||||
|
req.ctx.user = None
|
||||||
|
# CSRF protection
|
||||||
|
if req.method == "GET" and req.headers.upgrade != "websocket":
|
||||||
|
return # Ordinary GET requests are fine
|
||||||
|
# Check that origin matches host, for browsers which should all send Origin.
|
||||||
|
# Curl doesn't send any Origin header, so we allow it anyway.
|
||||||
|
origin = req.headers.origin
|
||||||
|
if origin and origin.split("//", 1)[1] != req.host:
|
||||||
|
raise Forbidden("Invalid origin: Cross-Site requests not permitted")
|
||||||
|
|
||||||
|
@app.before_server_start
|
||||||
|
def http_fileserver(app, _):
|
||||||
|
bp = Blueprint("fileserver")
|
||||||
|
bp.on_request(auth.verify)
|
||||||
|
bp.static("/files/", config.config.path, use_content_range=True, stream_large_files=True, directory_view=True)
|
||||||
|
app.blueprint(bp)
|
||||||
|
|
||||||
|
@app.get("/<path:path>", static=True)
|
||||||
|
async def wwwroot(req, path=""):
|
||||||
|
"""Frontend files only"""
|
||||||
name = filename.sanitize(unquote(path)) if path else "index.html"
|
name = filename.sanitize(unquote(path)) if path else "index.html"
|
||||||
try:
|
try:
|
||||||
index = files("cista").joinpath("wwwroot", name).read_bytes()
|
index = files("cista").joinpath("wwwroot", name).read_bytes()
|
||||||
|
@ -74,67 +58,3 @@ async def wwwroot(request, path=""):
|
||||||
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
|
raise NotFound(f"File not found: /{path}", extra={"name": name, "exception": repr(e)})
|
||||||
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
|
mime = mimetypes.guess_type(name)[0] or "application/octet-stream"
|
||||||
return raw(index, content_type=mime)
|
return raw(index, content_type=mime)
|
||||||
|
|
||||||
@app.websocket('/api/upload')
|
|
||||||
async def upload(request, ws):
|
|
||||||
alink = fileserver.alink
|
|
||||||
url = request.url_for("upload")
|
|
||||||
while True:
|
|
||||||
req = None
|
|
||||||
try:
|
|
||||||
text = await ws.recv()
|
|
||||||
if not isinstance(text, str):
|
|
||||||
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
|
||||||
req = msgspec.json.decode(text, type=FileRange)
|
|
||||||
pos = req.start
|
|
||||||
while pos < req.end and (data := await ws.recv()) and isinstance(data, bytes):
|
|
||||||
pos += await alink(("upload", req.name, pos, data, req.size))
|
|
||||||
if pos != req.end:
|
|
||||||
d = f"{len(data)} bytes" if isinstance(data, bytes) else data
|
|
||||||
raise ValueError(f"Expected {req.end - pos} more bytes, got {d}")
|
|
||||||
# Report success
|
|
||||||
res = StatusMsg(status="ack", req=req)
|
|
||||||
await asend(ws, res)
|
|
||||||
await ws.drain()
|
|
||||||
except Exception as e:
|
|
||||||
res = ErrorMsg(error=str(e), req=req)
|
|
||||||
await asend(ws, res)
|
|
||||||
logger.exception(repr(res), e)
|
|
||||||
return
|
|
||||||
|
|
||||||
@app.websocket('/api/download')
|
|
||||||
async def download(request, ws):
|
|
||||||
alink = fileserver.alink
|
|
||||||
while True:
|
|
||||||
req = None
|
|
||||||
try:
|
|
||||||
print("Waiting for download command")
|
|
||||||
text = await ws.recv()
|
|
||||||
if not isinstance(text, str):
|
|
||||||
raise ValueError(f"Expected JSON control, got binary len(data) = {len(text)}")
|
|
||||||
req = msgspec.json.decode(text, type=FileRange)
|
|
||||||
print("download", req)
|
|
||||||
pos = req.start
|
|
||||||
while pos < req.end:
|
|
||||||
end = min(req.end, pos + (1<<20))
|
|
||||||
data = await alink(("download", req.name, pos, end))
|
|
||||||
await asend(ws, data)
|
|
||||||
pos += len(data)
|
|
||||||
# Report success
|
|
||||||
res = StatusMsg(status="ack", req=req)
|
|
||||||
await asend(ws, res)
|
|
||||||
print(ws, dir(ws))
|
|
||||||
await ws.drain()
|
|
||||||
print(res)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
res = ErrorMsg(error=str(e), req=req)
|
|
||||||
await asend(ws, res)
|
|
||||||
logger.exception(repr(res), e)
|
|
||||||
return
|
|
||||||
|
|
||||||
@app.websocket("/api/control")
|
|
||||||
async def control(request, websocket):
|
|
||||||
cmd = msgspec.json.decode(await websocket.recv(), type=ControlBase)
|
|
||||||
await asyncio.to_thread(cmd)
|
|
||||||
await asend(websocket, StatusMsg(status="ack", req=cmd))
|
|
||||||
|
|
|
@ -6,7 +6,8 @@ from unicodedata import normalize
|
||||||
import argon2
|
import argon2
|
||||||
import msgspec
|
import msgspec
|
||||||
from html5tagger import Document
|
from html5tagger import Document
|
||||||
from sanic import BadRequest, Blueprint, Forbidden, html, json, redirect
|
from sanic import Blueprint, html, json, redirect
|
||||||
|
from sanic.exceptions import BadRequest, Forbidden, Unauthorized
|
||||||
|
|
||||||
from cista import config, session
|
from cista import config, session
|
||||||
|
|
||||||
|
@ -56,9 +57,18 @@ class LoginResponse(msgspec.Struct):
|
||||||
privileged: bool = False
|
privileged: bool = False
|
||||||
error: str = ""
|
error: str = ""
|
||||||
|
|
||||||
authbp = Blueprint("auth")
|
def verify(request, privileged=False):
|
||||||
|
"""Raise Unauthorized or Forbidden if the request is not authorized"""
|
||||||
|
if privileged:
|
||||||
|
if request.ctx.user:
|
||||||
|
if request.ctx.user.privileged: return
|
||||||
|
raise Forbidden("Access Forbidden: Only for privileged users")
|
||||||
|
elif config.config.public or request.ctx.user: return
|
||||||
|
raise Unauthorized("Login required", "cookie", context={"redirect": "/login"})
|
||||||
|
|
||||||
@authbp.get("/login")
|
bp = Blueprint("auth")
|
||||||
|
|
||||||
|
@bp.get("/login")
|
||||||
async def login_page(request):
|
async def login_page(request):
|
||||||
doc = Document("Cista Login")
|
doc = Document("Cista Login")
|
||||||
with doc.div(id="login"):
|
with doc.div(id="login"):
|
||||||
|
@ -82,7 +92,7 @@ async def login_page(request):
|
||||||
session.delete(res)
|
session.delete(res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@authbp.post("/login")
|
@bp.post("/login")
|
||||||
async def login_post(request):
|
async def login_post(request):
|
||||||
try:
|
try:
|
||||||
if request.headers.content_type == "application/json":
|
if request.headers.content_type == "application/json":
|
||||||
|
@ -108,7 +118,7 @@ async def login_post(request):
|
||||||
session.create(res, username)
|
session.create(res, username)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@authbp.post("/logout")
|
@bp.post("/logout")
|
||||||
async def logout_post(request):
|
async def logout_post(request):
|
||||||
s = request.ctx.session
|
s = request.ctx.session
|
||||||
msg = "Logged out" if s else "Not logged in"
|
msg = "Logged out" if s else "Not logged in"
|
||||||
|
|
|
@ -30,6 +30,7 @@ class File:
|
||||||
self.writable = True
|
self.writable = True
|
||||||
|
|
||||||
def write(self, pos, buffer, *, file_size=None):
|
def write(self, pos, buffer, *, file_size=None):
|
||||||
|
assert self.fd is not None
|
||||||
if not self.writable:
|
if not self.writable:
|
||||||
# Create/open file
|
# Create/open file
|
||||||
self.open_rw()
|
self.open_rw()
|
||||||
|
@ -39,6 +40,7 @@ class File:
|
||||||
os.write(self.fd, buffer)
|
os.write(self.fd, buffer)
|
||||||
|
|
||||||
def __getitem__(self, slice):
|
def __getitem__(self, slice):
|
||||||
|
assert self.fd is not None
|
||||||
if self.fd is None:
|
if self.fd is None:
|
||||||
self.open_ro()
|
self.open_ro()
|
||||||
os.lseek(self.fd, slice.start, os.SEEK_SET)
|
os.lseek(self.fd, slice.start, os.SEEK_SET)
|
||||||
|
@ -71,8 +73,6 @@ class FileServer:
|
||||||
try:
|
try:
|
||||||
for req in slink:
|
for req in slink:
|
||||||
with req as (command, *args):
|
with req as (command, *args):
|
||||||
if command == "control":
|
|
||||||
self.control(*args)
|
|
||||||
if command == "upload":
|
if command == "upload":
|
||||||
req.set_result(self.upload(*args))
|
req.set_result(self.upload(*args))
|
||||||
elif command == "download":
|
elif command == "download":
|
||||||
|
|
241
cista/index.html
241
cista/index.html
|
@ -1,241 +0,0 @@
|
||||||
<!DOCTYPE html>
|
|
||||||
<title>Storage</title>
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
font-family: sans-serif;
|
|
||||||
max-width: 100ch;
|
|
||||||
margin: 0 auto;
|
|
||||||
padding: 1em;
|
|
||||||
background-color: #333;
|
|
||||||
color: #eee;
|
|
||||||
}
|
|
||||||
td {
|
|
||||||
text-align: right;
|
|
||||||
padding: .5em;
|
|
||||||
}
|
|
||||||
td:first-child {
|
|
||||||
text-align: left;
|
|
||||||
}
|
|
||||||
a {
|
|
||||||
color: inherit;
|
|
||||||
text-decoration: none;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
<div>
|
|
||||||
<h2>Quick file upload</h2>
|
|
||||||
<p>Uses parallel WebSocket connections for increased bandwidth /api/upload</p>
|
|
||||||
<input type=file id=fileInput>
|
|
||||||
<progress id=progressBar value=0 max=1></progress>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<h2>Files</h2>
|
|
||||||
<ul id=file_list></ul>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
let files = {}
|
|
||||||
let flatfiles = {}
|
|
||||||
|
|
||||||
function createWatchSocket() {
|
|
||||||
const wsurl = new URL("/api/watch", location.href.replace(/^http/, 'ws'))
|
|
||||||
const ws = new WebSocket(wsurl)
|
|
||||||
ws.onmessage = event => {
|
|
||||||
msg = JSON.parse(event.data)
|
|
||||||
if (msg.update) {
|
|
||||||
tree_update(msg.update)
|
|
||||||
file_list(files)
|
|
||||||
} else {
|
|
||||||
console.log("Unkonwn message from watch socket", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
createWatchSocket()
|
|
||||||
|
|
||||||
function tree_update(msg) {
|
|
||||||
console.log("Tree update", msg)
|
|
||||||
let node = files
|
|
||||||
for (const elem of msg) {
|
|
||||||
if (elem.deleted) {
|
|
||||||
const p = node.dir[elem.name].path
|
|
||||||
delete node.dir[elem.name]
|
|
||||||
delete flatfiles[p]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if (elem.name !== undefined) node = node.dir[elem.name] ||= {}
|
|
||||||
if (elem.size !== undefined) node.size = elem.size
|
|
||||||
if (elem.mtime !== undefined) node.mtime = elem.mtime
|
|
||||||
if (elem.dir !== undefined) node.dir = elem.dir
|
|
||||||
}
|
|
||||||
// Update paths and flatfiles
|
|
||||||
files.path = "/"
|
|
||||||
const nodes = [files]
|
|
||||||
flatfiles = {}
|
|
||||||
while (node = nodes.pop()) {
|
|
||||||
flatfiles[node.path] = node
|
|
||||||
if (node.dir === undefined) continue
|
|
||||||
for (const name of Object.keys(node.dir)) {
|
|
||||||
const child = node.dir[name]
|
|
||||||
child.path = node.path + name + (child.dir === undefined ? "" : "/")
|
|
||||||
nodes.push(child)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var collator = new Intl.Collator(undefined, {numeric: true, sensitivity: 'base'});
|
|
||||||
|
|
||||||
const compare_path = (a, b) => collator.compare(a.path, b.path)
|
|
||||||
const compare_time = (a, b) => a.mtime > b.mtime
|
|
||||||
|
|
||||||
function file_list(files) {
|
|
||||||
const table = document.getElementById("file_list")
|
|
||||||
const sorted = Object.values(flatfiles).sort(compare_time)
|
|
||||||
table.innerHTML = ""
|
|
||||||
for (const f of sorted) {
|
|
||||||
const {path, size, mtime} = f
|
|
||||||
const tr = document.createElement("tr")
|
|
||||||
const name_td = document.createElement("td")
|
|
||||||
const size_td = document.createElement("td")
|
|
||||||
const mtime_td = document.createElement("td")
|
|
||||||
const a = document.createElement("a")
|
|
||||||
table.appendChild(tr)
|
|
||||||
tr.appendChild(name_td)
|
|
||||||
tr.appendChild(size_td)
|
|
||||||
tr.appendChild(mtime_td)
|
|
||||||
name_td.appendChild(a)
|
|
||||||
size_td.textContent = size
|
|
||||||
mtime_td.textContent = formatUnixDate(mtime)
|
|
||||||
a.textContent = path
|
|
||||||
a.href = `/files${path}`
|
|
||||||
/*a.onclick = event => {
|
|
||||||
if (window.showSaveFilePicker) {
|
|
||||||
event.preventDefault()
|
|
||||||
download_ws(name, size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
a.download = ""*/
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatUnixDate(t) {
|
|
||||||
const date = new Date(t * 1000)
|
|
||||||
const now = new Date()
|
|
||||||
const diff = date - now
|
|
||||||
const formatter = new Intl.RelativeTimeFormat('en', { numeric: 'auto' })
|
|
||||||
|
|
||||||
if (Math.abs(diff) <= 60000) {
|
|
||||||
return formatter.format(Math.round(diff / 1000), 'second')
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Math.abs(diff) <= 3600000) {
|
|
||||||
return formatter.format(Math.round(diff / 60000), 'minute')
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Math.abs(diff) <= 86400000) {
|
|
||||||
return formatter.format(Math.round(diff / 3600000), 'hour')
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Math.abs(diff) <= 604800000) {
|
|
||||||
return formatter.format(Math.round(diff / 86400000), 'day')
|
|
||||||
}
|
|
||||||
|
|
||||||
return date.toLocaleDateString()
|
|
||||||
}
|
|
||||||
|
|
||||||
async function download_ws(name, size) {
|
|
||||||
const fh = await window.showSaveFilePicker({
|
|
||||||
suggestedName: name,
|
|
||||||
})
|
|
||||||
const writer = await fh.createWritable()
|
|
||||||
writer.truncate(size)
|
|
||||||
const wsurl = new URL("/api/download", location.href.replace(/^http/, 'ws'))
|
|
||||||
const ws = new WebSocket(wsurl)
|
|
||||||
let pos = 0
|
|
||||||
ws.onopen = () => {
|
|
||||||
console.log("Downloading over WebSocket", name, size)
|
|
||||||
ws.send(JSON.stringify({name, start: 0, end: size, size}))
|
|
||||||
}
|
|
||||||
ws.onmessage = event => {
|
|
||||||
if (typeof event.data === 'string') {
|
|
||||||
const msg = JSON.parse(event.data)
|
|
||||||
console.log("Download finished", msg)
|
|
||||||
ws.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
console.log("Received chunk", name, pos, pos + event.data.size)
|
|
||||||
pos += event.data.size
|
|
||||||
writer.write(event.data)
|
|
||||||
}
|
|
||||||
ws.onclose = () => {
|
|
||||||
if (pos < size) {
|
|
||||||
console.log("Download aborted", name, pos)
|
|
||||||
writer.truncate(pos)
|
|
||||||
}
|
|
||||||
writer.close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const fileInput = document.getElementById("fileInput")
|
|
||||||
const progress = document.getElementById("progressBar")
|
|
||||||
const numConnections = 2
|
|
||||||
const chunkSize = 1<<20
|
|
||||||
const wsConnections = new Set()
|
|
||||||
|
|
||||||
//for (let i = 0; i < numConnections; i++) createUploadWS()
|
|
||||||
|
|
||||||
function createUploadWS() {
|
|
||||||
const wsurl = new URL("/api/upload", location.href.replace(/^http/, 'ws'))
|
|
||||||
const ws = new WebSocket(wsurl)
|
|
||||||
ws.binaryType = 'arraybuffer'
|
|
||||||
ws.onopen = () => {
|
|
||||||
wsConnections.add(ws)
|
|
||||||
console.log("Upload socket connected")
|
|
||||||
}
|
|
||||||
ws.onmessage = event => {
|
|
||||||
msg = JSON.parse(event.data)
|
|
||||||
if (msg.written) progress.value += +msg.written
|
|
||||||
else console.log(`Error: ${msg.error}`)
|
|
||||||
}
|
|
||||||
ws.onclose = () => {
|
|
||||||
wsConnections.delete(ws)
|
|
||||||
console.log("Upload socket disconnected, reconnecting...")
|
|
||||||
setTimeout(createUploadWS, 1000)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function load(file, start, end) {
|
|
||||||
const reader = new FileReader()
|
|
||||||
const load = new Promise(resolve => reader.onload = resolve)
|
|
||||||
reader.readAsArrayBuffer(file.slice(start, end))
|
|
||||||
const event = await load
|
|
||||||
return event.target.result
|
|
||||||
}
|
|
||||||
|
|
||||||
async function sendChunk(file, start, end, ws) {
|
|
||||||
const chunk = await load(file, start, end)
|
|
||||||
ws.send(JSON.stringify({
|
|
||||||
name: file.name,
|
|
||||||
size: file.size,
|
|
||||||
start: start,
|
|
||||||
end: end
|
|
||||||
}))
|
|
||||||
ws.send(chunk)
|
|
||||||
}
|
|
||||||
|
|
||||||
fileInput.addEventListener("change", async function() {
|
|
||||||
const file = this.files[0]
|
|
||||||
const numChunks = Math.ceil(file.size / chunkSize)
|
|
||||||
progress.value = 0
|
|
||||||
progress.max = file.size
|
|
||||||
|
|
||||||
console.log(wsConnections)
|
|
||||||
for (let i = 0; i < numChunks; i++) {
|
|
||||||
const ws = Array.from(wsConnections)[i % wsConnections.size]
|
|
||||||
const start = i * chunkSize
|
|
||||||
const end = Math.min(file.size, start + chunkSize)
|
|
||||||
const res = await sendChunk(file, start, end, ws)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
</script>
|
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
from sanic import BadRequest
|
from sanic import BadRequest
|
||||||
|
@ -71,14 +72,12 @@ class FileRange(msgspec.Struct):
|
||||||
start: int
|
start: int
|
||||||
end: int
|
end: int
|
||||||
|
|
||||||
class ErrorMsg(msgspec.Struct):
|
|
||||||
error: str
|
|
||||||
req: FileRange
|
|
||||||
|
|
||||||
class StatusMsg(msgspec.Struct):
|
class StatusMsg(msgspec.Struct):
|
||||||
status: str
|
status: str
|
||||||
req: FileRange
|
req: FileRange
|
||||||
|
|
||||||
|
class ErrorMsg(msgspec.Struct):
|
||||||
|
error: dict[str, Any]
|
||||||
|
|
||||||
## Directory listings
|
## Directory listings
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
||||||
|
|
||||||
from sanic import Sanic
|
from sanic import Sanic
|
||||||
|
|
||||||
from cista import config
|
from cista import config, server80
|
||||||
|
|
||||||
|
|
||||||
def run(dev=False):
|
def run(dev=False):
|
||||||
|
@ -15,11 +15,10 @@ def run(dev=False):
|
||||||
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1"
|
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "1"
|
||||||
if opts.get("ssl"):
|
if opts.get("ssl"):
|
||||||
# Run plain HTTP redirect/acme server on port 80
|
# Run plain HTTP redirect/acme server on port 80
|
||||||
from . import server80
|
|
||||||
server80.app.prepare(port=80, motd=False)
|
server80.app.prepare(port=80, motd=False)
|
||||||
domain = opts["host"]
|
domain = opts["host"]
|
||||||
opts["ssl"] = str(config.conffile.parent / domain)
|
opts["ssl"] = str(config.conffile.parent / domain) # type: ignore
|
||||||
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, access_log=True)
|
app.prepare(**opts, motd=False, dev=dev, auto_reload=dev, access_log=True) # type: ignore
|
||||||
Sanic.serve()
|
Sanic.serve()
|
||||||
|
|
||||||
def parse_listen(listen):
|
def parse_listen(listen):
|
||||||
|
|
|
@ -26,7 +26,8 @@ def create(res, username, **kwargs):
|
||||||
def update(res, s, **kwargs):
|
def update(res, s, **kwargs):
|
||||||
s.update(kwargs)
|
s.update(kwargs)
|
||||||
s = jwt.encode(s, session_secret())
|
s = jwt.encode(s, session_secret())
|
||||||
res.cookies.add_cookie("s", s, httponly=True, max_age=max(1, s["exp"] - int(time())))
|
max_age = max(1, s["exp"] - int(time())) # type: ignore
|
||||||
|
res.cookies.add_cookie("s", s, httponly=True, max_age=max_age)
|
||||||
|
|
||||||
def delete(res):
|
def delete(res):
|
||||||
res.cookies.delete_cookie("s")
|
res.cookies.delete_cookie("s")
|
||||||
|
|
58
cista/util/apphelpers.py
Normal file
58
cista/util/apphelpers.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
from sanic import errorpages
|
||||||
|
from sanic.exceptions import SanicException
|
||||||
|
from sanic.log import logger
|
||||||
|
from sanic.response import raw, redirect
|
||||||
|
|
||||||
|
from cista import auth
|
||||||
|
from cista.protocol import ErrorMsg
|
||||||
|
|
||||||
|
|
||||||
|
def asend(ws, msg):
|
||||||
|
"""Send JSON message or bytes to a websocket"""
|
||||||
|
return ws.send(msg if isinstance(msg, bytes) else msgspec.json.encode(msg).decode())
|
||||||
|
|
||||||
|
def jres(data, **kwargs):
|
||||||
|
"""JSON Sanic response, using msgspec encoding"""
|
||||||
|
return raw(msgspec.json.encode(data), content_type="application/json", **kwargs)
|
||||||
|
|
||||||
|
async def handle_sanic_exception(request, e):
|
||||||
|
logger.exception(e)
|
||||||
|
context, code = {}, 500
|
||||||
|
message = str(e)
|
||||||
|
if isinstance(e, SanicException):
|
||||||
|
context = e.context or {}
|
||||||
|
code = e.status_code
|
||||||
|
if not message or not request.app.debug and code == 500:
|
||||||
|
message = "Internal Server Error"
|
||||||
|
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
|
||||||
|
# Non-browsers get JSON errors
|
||||||
|
if "text/html" not in request.headers.accept:
|
||||||
|
return jres(ErrorMsg({"code": code, "message": message, **context}), status=code)
|
||||||
|
# Redirections flash the error message via cookies
|
||||||
|
if "redirect" in context:
|
||||||
|
res = redirect(context["redirect"])
|
||||||
|
res.cookies.add_cookie("message", message, max_age=5)
|
||||||
|
return res
|
||||||
|
# Otherwise use Sanic's default error page
|
||||||
|
return errorpages.HTMLRenderer(request, e, debug=request.app.debug).full()
|
||||||
|
|
||||||
|
def websocket_wrapper(handler):
|
||||||
|
"""Decorator for websocket handlers that catches exceptions and sends them back to the client"""
|
||||||
|
@wraps(handler)
|
||||||
|
async def wrapper(request, ws, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
auth.verify(request)
|
||||||
|
await handler(request, ws, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
context, code, message = {}, 500, str(e) or "Internal Server Error"
|
||||||
|
if isinstance(e, SanicException):
|
||||||
|
context = e.context or {}
|
||||||
|
code = e.status_code
|
||||||
|
message = f"⚠️ {message}" if code < 500 else f"🛑 {message}"
|
||||||
|
await asend(ws, ErrorMsg({"code": code, "message": message, **context}))
|
||||||
|
raise
|
||||||
|
return wrapper
|
|
@ -1,7 +1,9 @@
|
||||||
from pathvalidate import sanitize_filepath
|
|
||||||
import unicodedata
|
import unicodedata
|
||||||
from pathlib import PurePosixPath
|
from pathlib import PurePosixPath
|
||||||
|
|
||||||
|
from pathvalidate import sanitize_filepath
|
||||||
|
|
||||||
|
|
||||||
def sanitize(filename: str) -> str:
|
def sanitize(filename: str) -> str:
|
||||||
filename = unicodedata.normalize("NFC", filename)
|
filename = unicodedata.normalize("NFC", filename)
|
||||||
# UNIX filenames can contain backslashes but for compatibility we replace them with dashes
|
# UNIX filenames can contain backslashes but for compatibility we replace them with dashes
|
||||||
|
|
|
@ -8,6 +8,7 @@ import msgspec
|
||||||
|
|
||||||
from cista import config
|
from cista import config
|
||||||
from cista.protocol import DirEntry, FileEntry, UpdateEntry
|
from cista.protocol import DirEntry, FileEntry, UpdateEntry
|
||||||
|
from cista.util.apphelpers import websocket_wrapper
|
||||||
|
|
||||||
pubsub = {}
|
pubsub = {}
|
||||||
tree = {"": None}
|
tree = {"": None}
|
||||||
|
@ -132,30 +133,14 @@ async def broadcast(msg):
|
||||||
for queue in pubsub.values():
|
for queue in pubsub.values():
|
||||||
await queue.put_nowait(msg)
|
await queue.put_nowait(msg)
|
||||||
|
|
||||||
def register(app, url):
|
async def start(app, loop):
|
||||||
@app.before_server_start
|
|
||||||
async def start_watcher(app, loop):
|
|
||||||
global rootpath
|
global rootpath
|
||||||
config.load_config()
|
config.load_config()
|
||||||
rootpath = config.config.path
|
rootpath = config.config.path
|
||||||
app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop])
|
app.ctx.watcher = threading.Thread(target=watcher_thread, args=[loop])
|
||||||
app.ctx.watcher.start()
|
app.ctx.watcher.start()
|
||||||
|
|
||||||
@app.after_server_stop
|
async def stop(app, loop):
|
||||||
async def stop_watcher(app, _):
|
|
||||||
global quit
|
global quit
|
||||||
quit = True
|
quit = True
|
||||||
app.ctx.watcher.join()
|
app.ctx.watcher.join()
|
||||||
|
|
||||||
@app.websocket(url)
|
|
||||||
async def watch(request, ws):
|
|
||||||
try:
|
|
||||||
with tree_lock:
|
|
||||||
q = pubsub[ws] = asyncio.Queue()
|
|
||||||
# Init with full tree
|
|
||||||
await ws.send(refresh())
|
|
||||||
# Send updates
|
|
||||||
while True:
|
|
||||||
await ws.send(await q.get())
|
|
||||||
finally:
|
|
||||||
del pubsub[ws]
|
|
||||||
|
|
|
@ -62,3 +62,8 @@ addopts = [
|
||||||
testpaths = [
|
testpaths = [
|
||||||
"tests",
|
"tests",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
#src_paths = ["cista", "tests"]
|
||||||
|
line_length = 120
|
||||||
|
multi_line_output = 5
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import pytest
|
|
||||||
from cista import config
|
|
||||||
from cista.protocol import MkDir, Rename, Rm, Mv, Cp
|
|
||||||
from pathlib import Path
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cista import config
|
||||||
|
from cista.protocol import Cp, MkDir, Mv, Rename, Rm
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def setup_temp_dir():
|
def setup_temp_dir():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user