Create SanicASGITestClient and refactor ASGI methods

This commit is contained in:
Adam Hopkins
2019-05-21 19:30:55 +03:00
parent 4767a67acd
commit 8a56da84e6
4 changed files with 1390 additions and 282 deletions

View File

@@ -16,6 +16,7 @@ from typing import Any, Optional, Type, Union
from urllib.parse import urlencode, urlunparse
from sanic import reloader_helpers
from sanic.asgi import ASGIApp
from sanic.blueprint_group import BlueprintGroup
from sanic.config import BASE_LOGO, Config
from sanic.constants import HTTP_METHODS
@@ -27,7 +28,7 @@ from sanic.request import Request
from sanic.router import Router
from sanic.server import HttpProtocol, Signal, serve, serve_multiple
from sanic.static import register as static_register
from sanic.testing import SanicTestClient
from sanic.testing import SanicTestClient, SanicASGITestClient
from sanic.views import CompositionView
from sanic.websocket import ConnectionClosed, WebSocketProtocol
@@ -981,7 +982,9 @@ class Sanic:
raise CancelledError()
# pass the response to the correct callback
if write_callback is None or isinstance(response, StreamingHTTPResponse):
if write_callback is None or isinstance(
response, StreamingHTTPResponse
):
await stream_callback(response)
else:
write_callback(response)
@@ -994,6 +997,10 @@ class Sanic:
def test_client(self):
return SanicTestClient(self)
@property
def asgi_client(self):
return SanicASGITestClient(self)
# -------------------------------------------------------------------- #
# Execution
# -------------------------------------------------------------------- #
@@ -1120,9 +1127,6 @@ class Sanic:
"""This kills the Sanic"""
get_event_loop().stop()
def __call__(self, scope):
return ASGIApp(self, scope)
async def create_server(
self,
host: Optional[str] = None,
@@ -1365,79 +1369,10 @@ class Sanic:
parts = [self.name, *parts]
return ".".join(parts)
# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #
class MockTransport:
def __init__(self, scope):
self.scope = scope
def get_extra_info(self, info):
if info == 'peername':
return self.scope.get('server')
elif info == 'sslcontext':
return self.scope.get('scheme') in ["https", "wss"]
class ASGIApp:
def __init__(self, sanic_app, scope):
self.sanic_app = sanic_app
url_bytes = scope.get('root_path', '') + scope['path']
url_bytes = url_bytes.encode('latin-1')
url_bytes += scope['query_string']
headers = CIMultiDict([
(key.decode('latin-1'), value.decode('latin-1'))
for key, value in scope.get('headers', [])
])
version = scope['http_version']
method = scope['method']
self.request = Request(url_bytes, headers, version, method, MockTransport(scope))
self.request.app = sanic_app
async def read_body(self, receive):
"""
Read and return the entire body from an incoming ASGI message.
"""
body = b''
more_body = True
while more_body:
message = await receive()
body += message.get('body', b'')
more_body = message.get('more_body', False)
return body
async def __call__(self, receive, send):
"""
Handle the incoming request.
"""
self.send = send
self.request.body = await self.read_body(receive)
handler = self.sanic_app.handle_request
await handler(self.request, None, self.stream_callback)
async def stream_callback(self, response):
"""
Write the response.
"""
if isinstance(response, StreamingHTTPResponse):
raise NotImplementedError('Not supported')
headers = [
(str(name).encode('latin-1'), str(value).encode('latin-1'))
for name, value in response.headers.items()
]
if 'content-length' not in response.headers:
headers += [(
b'content-length',
str(len(response.body)).encode('latin-1')
)]
await self.send({
'type': 'http.response.start',
'status': response.status,
'headers': headers
})
await self.send({
'type': 'http.response.body',
'body': response.body,
'more_body': False
})
async def __call__(self, scope, receive, send):
asgi_app = ASGIApp(self, scope, receive, send)
await asgi_app()