[Trio] Quick fixes to make Sanic usable on hypercorn -k trio myweb.app (#1767)

* Quick fixes to make Sanic usable on hypercorn -k trio myweb.app

* Quick'n dirty compatibility and autodetection of hypercorn trio mode.

* mypy ignore for aiofiles/trio.

* lint
This commit is contained in:
L. Kärkkäinen 2020-01-20 18:29:06 +02:00 committed by Stephen Sadowski
parent 801595e24a
commit e908ca8cef
3 changed files with 28 additions and 11 deletions

View File

@ -1,6 +1,25 @@
from sys import argv
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
class Header(CIMultiDict): class Header(CIMultiDict):
def get_all(self, key): def get_all(self, key):
return self.getall(key, default=[]) return self.getall(key, default=[])
use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio:
from trio import open_file as open_async, Path # type: ignore
def stat_async(path):
return Path(path).stat()
else:
from aiofiles import open as aio_open # type: ignore
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
async def open_async(file, mode="r", **kwargs):
return aio_open(file, mode, **kwargs)

View File

@ -3,9 +3,7 @@ from mimetypes import guess_type
from os import path from os import path
from urllib.parse import quote_plus from urllib.parse import quote_plus
from aiofiles import open as open_async # type: ignore from sanic.compat import Header, open_async
from sanic.compat import Header
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
from sanic.headers import format_http1 from sanic.headers import format_http1
from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers
@ -300,7 +298,7 @@ async def file(
) )
filename = filename or path.split(location)[-1] filename = filename or path.split(location)[-1]
async with open_async(location, mode="rb") as _file: async with await open_async(location, mode="rb") as _file:
if _range: if _range:
await _file.seek(_range.start) await _file.seek(_range.start)
out_stream = await _file.read(_range.size) out_stream = await _file.read(_range.size)
@ -349,7 +347,8 @@ async def file_stream(
) )
filename = filename or path.split(location)[-1] filename = filename or path.split(location)[-1]
_file = await open_async(location, mode="rb") _filectx = await open_async(location, mode="rb")
_file = await _filectx.__aenter__() # Will be exited by _streaming_fn
async def _streaming_fn(response): async def _streaming_fn(response):
nonlocal _file, chunk_size nonlocal _file, chunk_size
@ -371,7 +370,7 @@ async def file_stream(
break break
await response.write(content) await response.write(content)
finally: finally:
await _file.close() await _filectx.__aexit__(None, None, None)
return # Returning from this fn closes the stream return # Returning from this fn closes the stream
mime_type = mime_type or guess_type(filename)[0] or "text/plain" mime_type = mime_type or guess_type(filename)[0] or "text/plain"

View File

@ -4,8 +4,7 @@ from re import sub
from time import gmtime, strftime from time import gmtime, strftime
from urllib.parse import unquote from urllib.parse import unquote
from aiofiles.os import stat # type: ignore from sanic.compat import stat_async
from sanic.exceptions import ( from sanic.exceptions import (
ContentRangeError, ContentRangeError,
FileNotFound, FileNotFound,
@ -84,7 +83,7 @@ def register(
# and it has not been modified since # and it has not been modified since
stats = None stats = None
if use_modified_since: if use_modified_since:
stats = await stat(file_path) stats = await stat_async(file_path)
modified_since = strftime( modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
) )
@ -95,7 +94,7 @@ def register(
if use_content_range: if use_content_range:
_range = None _range = None
if not stats: if not stats:
stats = await stat(file_path) stats = await stat_async(file_path)
headers["Accept-Ranges"] = "bytes" headers["Accept-Ranges"] = "bytes"
headers["Content-Length"] = str(stats.st_size) headers["Content-Length"] = str(stats.st_size)
if request.method != "HEAD": if request.method != "HEAD":
@ -120,7 +119,7 @@ def register(
threshold = 1024 * 1024 threshold = 1024 * 1024
if not stats: if not stats:
stats = await stat(file_path) stats = await stat_async(file_path)
if stats.st_size >= threshold: if stats.st_size >= threshold:
return await file_stream( return await file_stream(
file_path, headers=headers, _range=_range file_path, headers=headers, _range=_range