Create SanicASGITestClient and refactor ASGI methods
This commit is contained in:
95
sanic/app.py
95
sanic/app.py
@@ -16,6 +16,7 @@ from typing import Any, Optional, Type, Union
|
||||
from urllib.parse import urlencode, urlunparse
|
||||
|
||||
from sanic import reloader_helpers
|
||||
from sanic.asgi import ASGIApp
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.config import BASE_LOGO, Config
|
||||
from sanic.constants import HTTP_METHODS
|
||||
@@ -27,7 +28,7 @@ from sanic.request import Request
|
||||
from sanic.router import Router
|
||||
from sanic.server import HttpProtocol, Signal, serve, serve_multiple
|
||||
from sanic.static import register as static_register
|
||||
from sanic.testing import SanicTestClient
|
||||
from sanic.testing import SanicTestClient, SanicASGITestClient
|
||||
from sanic.views import CompositionView
|
||||
from sanic.websocket import ConnectionClosed, WebSocketProtocol
|
||||
|
||||
@@ -981,7 +982,9 @@ class Sanic:
|
||||
raise CancelledError()
|
||||
|
||||
# pass the response to the correct callback
|
||||
if write_callback is None or isinstance(response, StreamingHTTPResponse):
|
||||
if write_callback is None or isinstance(
|
||||
response, StreamingHTTPResponse
|
||||
):
|
||||
await stream_callback(response)
|
||||
else:
|
||||
write_callback(response)
|
||||
@@ -994,6 +997,10 @@ class Sanic:
|
||||
def test_client(self):
|
||||
return SanicTestClient(self)
|
||||
|
||||
@property
|
||||
def asgi_client(self):
|
||||
return SanicASGITestClient(self)
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Execution
|
||||
# -------------------------------------------------------------------- #
|
||||
@@ -1120,9 +1127,6 @@ class Sanic:
|
||||
"""This kills the Sanic"""
|
||||
get_event_loop().stop()
|
||||
|
||||
def __call__(self, scope):
|
||||
return ASGIApp(self, scope)
|
||||
|
||||
async def create_server(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
@@ -1365,79 +1369,10 @@ class Sanic:
|
||||
parts = [self.name, *parts]
|
||||
return ".".join(parts)
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# ASGI
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, scope):
|
||||
self.scope = scope
|
||||
|
||||
def get_extra_info(self, info):
|
||||
if info == 'peername':
|
||||
return self.scope.get('server')
|
||||
elif info == 'sslcontext':
|
||||
return self.scope.get('scheme') in ["https", "wss"]
|
||||
|
||||
class ASGIApp:
|
||||
def __init__(self, sanic_app, scope):
|
||||
self.sanic_app = sanic_app
|
||||
url_bytes = scope.get('root_path', '') + scope['path']
|
||||
url_bytes = url_bytes.encode('latin-1')
|
||||
url_bytes += scope['query_string']
|
||||
headers = CIMultiDict([
|
||||
(key.decode('latin-1'), value.decode('latin-1'))
|
||||
for key, value in scope.get('headers', [])
|
||||
])
|
||||
version = scope['http_version']
|
||||
method = scope['method']
|
||||
self.request = Request(url_bytes, headers, version, method, MockTransport(scope))
|
||||
self.request.app = sanic_app
|
||||
|
||||
async def read_body(self, receive):
|
||||
"""
|
||||
Read and return the entire body from an incoming ASGI message.
|
||||
"""
|
||||
body = b''
|
||||
more_body = True
|
||||
|
||||
while more_body:
|
||||
message = await receive()
|
||||
body += message.get('body', b'')
|
||||
more_body = message.get('more_body', False)
|
||||
|
||||
return body
|
||||
|
||||
async def __call__(self, receive, send):
|
||||
"""
|
||||
Handle the incoming request.
|
||||
"""
|
||||
self.send = send
|
||||
self.request.body = await self.read_body(receive)
|
||||
handler = self.sanic_app.handle_request
|
||||
await handler(self.request, None, self.stream_callback)
|
||||
|
||||
async def stream_callback(self, response):
|
||||
"""
|
||||
Write the response.
|
||||
"""
|
||||
if isinstance(response, StreamingHTTPResponse):
|
||||
raise NotImplementedError('Not supported')
|
||||
|
||||
headers = [
|
||||
(str(name).encode('latin-1'), str(value).encode('latin-1'))
|
||||
for name, value in response.headers.items()
|
||||
]
|
||||
if 'content-length' not in response.headers:
|
||||
headers += [(
|
||||
b'content-length',
|
||||
str(len(response.body)).encode('latin-1')
|
||||
)]
|
||||
|
||||
await self.send({
|
||||
'type': 'http.response.start',
|
||||
'status': response.status,
|
||||
'headers': headers
|
||||
})
|
||||
await self.send({
|
||||
'type': 'http.response.body',
|
||||
'body': response.body,
|
||||
'more_body': False
|
||||
})
|
||||
async def __call__(self, scope, receive, send):
|
||||
asgi_app = ASGIApp(self, scope, receive, send)
|
||||
await asgi_app()
|
||||
|
||||
Reference in New Issue
Block a user