ASGI refactoring attempt

This commit is contained in:
Tom Christie 2019-01-18 14:50:42 +00:00
parent 2af229eb1a
commit 95526a82de
2 changed files with 402 additions and 106 deletions

View File

@ -8,6 +8,7 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
from collections import defaultdict, deque from collections import defaultdict, deque
from functools import partial from functools import partial
from inspect import getmodulename, isawaitable, signature, stack from inspect import getmodulename, isawaitable, signature, stack
from multidict import CIMultiDict
from socket import socket from socket import socket
from ssl import Purpose, SSLContext, create_default_context from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc from traceback import format_exc
@ -21,6 +22,7 @@ from sanic.exceptions import SanicException, ServerError, URLBuildError
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.request import Request
from sanic.router import Router from sanic.router import Router
from sanic.server import HttpProtocol, Signal, serve, serve_multiple from sanic.server import HttpProtocol, Signal, serve, serve_multiple
from sanic.static import register as static_register from sanic.static import register as static_register
@ -967,7 +969,7 @@ class Sanic:
raise CancelledError() raise CancelledError()
# pass the response to the correct callback # pass the response to the correct callback
if isinstance(response, StreamingHTTPResponse): if write_callback is None or isinstance(response, StreamingHTTPResponse):
await stream_callback(response) await stream_callback(response)
else: else:
write_callback(response) write_callback(response)
@ -1106,9 +1108,8 @@ class Sanic:
"""This kills the Sanic""" """This kills the Sanic"""
get_event_loop().stop() get_event_loop().stop()
def __call__(self): def __call__(self, scope):
"""gunicorn compatibility""" return ASGIApp(self, scope)
return self
async def create_server( async def create_server(
self, self,
@ -1339,3 +1340,80 @@ class Sanic:
def _build_endpoint_name(self, *parts): def _build_endpoint_name(self, *parts):
parts = [self.name, *parts] parts = [self.name, *parts]
return ".".join(parts) return ".".join(parts)
class MockTransport:
def __init__(self, scope):
self.scope = scope
def get_extra_info(self, info):
if info == 'peername':
return self.scope.get('server')
elif info == 'sslcontext':
return self.scope.get('scheme') in ["https", "wss"]
class ASGIApp:
def __init__(self, sanic_app, scope):
self.sanic_app = sanic_app
url_bytes = scope.get('root_path', '') + scope['path']
url_bytes = url_bytes.encode('latin-1')
url_bytes += scope['query_string']
headers = CIMultiDict([
(key.decode('latin-1'), value.decode('latin-1'))
for key, value in scope.get('headers', [])
])
version = scope['http_version']
method = scope['method']
self.request = Request(url_bytes, headers, version, method, MockTransport(scope))
self.request.app = sanic_app
async def read_body(self, receive):
"""
Read and return the entire body from an incoming ASGI message.
"""
body = b''
more_body = True
while more_body:
message = await receive()
body += message.get('body', b'')
more_body = message.get('more_body', False)
return body
async def __call__(self, receive, send):
"""
Handle the incoming request.
"""
self.send = send
self.request.body = await self.read_body(receive)
handler = self.sanic_app.handle_request
await handler(self.request, None, self.stream_callback)
async def stream_callback(self, response):
"""
Write the response.
"""
if isinstance(response, StreamingHTTPResponse):
raise NotImplementedError('Not supported')
headers = [
(str(name).encode('latin-1'), str(value).encode('latin-1'))
for name, value in response.headers.items()
]
if 'content-length' not in response.headers:
headers += [(
b'content-length',
str(len(response.body)).encode('latin-1')
)]
await self.send({
'type': 'http.response.start',
'status': response.status,
'headers': headers
})
await self.send({
'type': 'http.response.body',
'body': response.body,
'more_body': False
})

View File

@ -1,137 +1,355 @@
from json import JSONDecodeError
from sanic.exceptions import MethodNotSupported import asyncio
from sanic.log import logger import http
from sanic.response import text import io
import json
import queue
import threading
import types
import typing
from urllib.parse import unquote, urljoin, urlparse, parse_qs
import requests
from starlette.types import ASGIApp, Message, Scope
from starlette.websockets import WebSocketDisconnect
HOST = "127.0.0.1" HOST = "127.0.0.1"
PORT = 42101 PORT = 42101
class SanicTestClient: # Annotations for `Session.request()`
def __init__(self, app, port=PORT): Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
]
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
TimeOut = typing.Union[float, typing.Tuple[float, float]]
FileType = typing.MutableMapping[str, typing.IO]
AuthType = typing.Union[
typing.Tuple[str, str],
requests.auth.AuthBase,
typing.Callable[[requests.Request], requests.Request],
]
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key: str, default: str) -> str:
return self.getheaders(key)
class _MockOriginalResponse:
"""
We have to jump through some hoops to present the response as if
it was made using urllib3.
"""
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
self.msg = _HeaderDict(headers)
self.closed = False
def isclosed(self) -> bool:
return self.closed
class _Upgrade(Exception):
def __init__(self, session: "WebSocketTestSession") -> None:
self.session = session
def _get_reason_phrase(status_code: int) -> str:
try:
return http.HTTPStatus(status_code).phrase
except ValueError:
return ""
class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None:
self.app = app self.app = app
self.port = port self.raise_server_exceptions = raise_server_exceptions
async def _local_request(self, method, uri, cookies=None, *args, **kwargs): def send( # type: ignore
import aiohttp self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
) -> requests.Response:
if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")): scheme, netloc, path, params, query, fragement = urlparse( # type: ignore
url = uri request.url
else:
url = "http://{host}:{port}{uri}".format(
host=HOST, port=self.port, uri=uri
) )
logger.info(url) default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
conn = aiohttp.TCPConnector(ssl=False)
async with aiohttp.ClientSession( if ":" in netloc:
cookies=cookies, connector=conn host, port_string = netloc.split(":", 1)
) as session: port = int(port_string)
async with getattr(session, method.lower())( else:
url, *args, **kwargs host = netloc
) as response: port = default_port
# Include the 'host' header.
if "host" in request.headers:
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
elif port == default_port:
headers = [(b"host", host.encode())]
else:
headers = [(b"host", ("%s:%d" % (host, port)).encode())]
# Include other request headers.
headers += [
(key.lower().encode(), value.encode())
for key, value in request.headers.items()
]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
if subprotocol is None:
subprotocols = [] # type: typing.Sequence[str]
else:
subprotocols = [value.strip() for value in subprotocol.split(",")]
scope = {
"type": "websocket",
"path": unquote(path),
"root_path": "",
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
}
session = WebSocketTestSession(self.app, scope)
raise _Upgrade(session)
scope = {
"type": "http",
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"root_path": "",
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.template": {}},
}
async def receive() -> Message:
nonlocal request_complete, response_complete
if request_complete:
while not response_complete:
await asyncio.sleep(0.0001)
return {"type": "http.disconnect"}
body = request.body
if isinstance(body, str):
body_bytes = body.encode("utf-8") # type: bytes
elif body is None:
body_bytes = b""
elif isinstance(body, types.GeneratorType):
try: try:
response.text = await response.text() chunk = body.send(None)
except UnicodeDecodeError: if isinstance(chunk, str):
response.text = None chunk = chunk.encode("utf-8")
return {"type": "http.request", "body": chunk, "more_body": True}
except StopIteration:
request_complete = True
return {"type": "http.request", "body": b""}
else:
body_bytes = body
request_complete = True
return {"type": "http.request", "body": body_bytes}
async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, response_complete, template, context
if message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
raw_kwargs["version"] = 11
raw_kwargs["status"] = message["status"]
raw_kwargs["reason"] = _get_reason_phrase(message["status"])
raw_kwargs["headers"] = [
(key.decode(), value.decode()) for key, value in message["headers"]
]
raw_kwargs["preload_content"] = False
raw_kwargs["original_response"] = _MockOriginalResponse(
raw_kwargs["headers"]
)
response_started = True
elif message["type"] == "http.response.body":
assert (
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
not response_complete
), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
raw_kwargs["body"].write(body)
if not more_body:
raw_kwargs["body"].seek(0)
response_complete = True
elif message["type"] == "http.response.template":
template = message["template"]
context = message["context"]
request_complete = False
response_started = False
response_complete = False
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
template = None
context = None
try: try:
response.json = await response.json() loop = asyncio.get_event_loop()
except ( except RuntimeError:
JSONDecodeError, loop = asyncio.new_event_loop()
UnicodeDecodeError, asyncio.set_event_loop(loop)
aiohttp.ClientResponseError,
):
response.json = None
response.body = await response.read() self.app.is_running = True
try:
connection = self.app(scope)
loop.run_until_complete(connection(receive, send))
except BaseException as exc:
if self.raise_server_exceptions:
raise exc from None
if self.raise_server_exceptions:
assert response_started, "TestClient did not receive any response."
elif not response_started:
raw_kwargs = {
"version": 11,
"status": 500,
"reason": "Internal Server Error",
"headers": [],
"preload_content": False,
"original_response": _MockOriginalResponse([]),
"body": io.BytesIO(),
}
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
response = self.build_response(request, raw)
if template is not None:
response.template = template
response.context = context
return response return response
def _sanic_endpoint_test(
class SanicTestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
def __init__(
self, self,
method="get", app: ASGIApp,
uri="/", base_url: str = "http://%s:%d" % (HOST, PORT),
gather_request=True, raise_server_exceptions: bool = True,
debug=False, ) -> None:
server_kwargs={"auto_reload": False}, super(SanicTestClient, self).__init__()
*request_args, adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions)
**request_kwargs self.mount("http://", adapter)
): self.mount("https://", adapter)
results = [None, None] self.mount("ws://", adapter)
exceptions = [] self.mount("wss://", adapter)
self.headers.update({"user-agent": "testclient"})
self.app = app
self.base_url = base_url
def request(
self,
method: str,
url: str = '/',
params: Params = None,
data: DataType = None,
headers: typing.MutableMapping[str, str] = None,
cookies: Cookies = None,
files: FileType = None,
auth: AuthType = None,
timeout: TimeOut = None,
allow_redirects: bool = None,
proxies: typing.MutableMapping[str, str] = None,
hooks: typing.Any = None,
stream: bool = None,
verify: typing.Union[bool, str] = None,
cert: typing.Union[str, typing.Tuple[str, str]] = None,
json: typing.Any = None,
debug = None,
gather_request = True
) -> requests.Response:
if debug is not None:
self.app.debug = debug
url = urljoin(self.base_url, url)
response = super().request(
method,
url,
params=params,
data=data,
headers=headers,
cookies=cookies,
files=files,
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
proxies=proxies,
hooks=hooks,
stream=stream,
verify=verify,
cert=cert,
json=json,
)
response.status = response.status_code
response.body = response.content
try:
response.json = response.json()
except:
response.json = None
if gather_request: if gather_request:
request = response.request
def _collect_request(request): parsed = urlparse(request.url)
if results[0] is None: request.scheme = parsed.scheme
results[0] = request request.path = parsed.path
request.args = parse_qs(parsed.query)
self.app.request_middleware.appendleft(_collect_request)
@self.app.exception(MethodNotSupported)
async def error_handler(request, exception):
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text(
"", exception.status_code, headers=exception.headers
)
else:
return self.app.error_handler.default(request, exception)
@self.app.listener("after_server_start")
async def _collect_response(sanic, loop):
try:
response = await self._local_request(
method, uri, *request_args, **request_kwargs
)
results[-1] = response
except Exception as e:
logger.exception("Exception")
exceptions.append(e)
self.app.stop()
self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs)
self.app.listeners["after_server_start"].pop()
if exceptions:
raise ValueError("Exception during request: {}".format(exceptions))
if gather_request:
try:
request, response = results
return request, response return request, response
except BaseException:
raise ValueError( return response
"Request and response object expected, got ({})".format(
results
)
)
else:
try:
return results[-1]
except BaseException:
raise ValueError(
"Request object expected, got ({})".format(results)
)
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
return self._sanic_endpoint_test("get", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("get", *args, **kwargs)
def post(self, *args, **kwargs): def post(self, *args, **kwargs):
return self._sanic_endpoint_test("post", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("post", *args, **kwargs)
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
return self._sanic_endpoint_test("put", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("put", *args, **kwargs)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
return self._sanic_endpoint_test("delete", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("delete", *args, **kwargs)
def patch(self, *args, **kwargs): def patch(self, *args, **kwargs):
return self._sanic_endpoint_test("patch", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("patch", *args, **kwargs)
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
return self._sanic_endpoint_test("options", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("options", *args, **kwargs)
def head(self, *args, **kwargs): def head(self, *args, **kwargs):
return self._sanic_endpoint_test("head", *args, **kwargs) if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("head", *args, **kwargs)