Add custom request support to ASGI mode; fix a couple tests

Undo change to request stream test
This commit is contained in:
Adam Hopkins 2019-06-24 22:49:11 +03:00
parent 966b05b47e
commit a2666a2b8a
5 changed files with 70 additions and 40 deletions

View File

@ -88,7 +88,7 @@ class MockTransport:
self._websocket_connection = WebSocketConnection(send, receive) self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection return self._websocket_connection
def add_task(self) -> None: # noqa def add_task(self) -> None:
raise NotImplementedError raise NotImplementedError
async def send(self, data) -> None: async def send(self, data) -> None:
@ -119,27 +119,20 @@ class Lifespan:
"the ASGI server is stopped." "the ASGI server is stopped."
) )
# async def pre_startup(self) -> None:
# for handler in self.asgi_app.sanic_app.listeners[
# "before_server_start"
# ]:
# response = handler(
# self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
# )
# if isawaitable(response):
# await response
async def startup(self) -> None: async def startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[ """
"before_server_start" Gather the listeners to fire on server start.
]: Because we are using a third-party server and not Sanic server, we do
response = handler( not have access to fire anything BEFORE the server starts.
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop Therefore, we fire before_server_start and after_server_start
) in sequence since the ASGI lifespan protocol only supports a single
if isawaitable(response): startup event.
await response """
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_start", []
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: for handler in listeners:
response = handler( response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
) )
@ -147,14 +140,19 @@ class Lifespan:
await response await response
async def shutdown(self) -> None: async def shutdown(self) -> None:
for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]: """
response = handler( Gather the listeners to fire on server stop.
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop Because we are using a third-party server and not Sanic server, we do
) not have access to fire anything AFTER the server stops.
if isawaitable(response): Therefore, we fire before_server_stop and after_server_stop
await response in sequence since the ASGI lifespan protocol only supports a single
shutdown event.
"""
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_stop", []
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in self.asgi_app.sanic_app.listeners["after_server_stop"]: for handler in listeners:
response = handler( response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
) )
@ -223,7 +221,8 @@ class ASGIApp:
# TODO: # TODO:
# - close connection # - close connection
instance.request = Request( request_class = sanic_app.request_class or Request
instance.request = request_class(
url_bytes, url_bytes,
headers, headers,
version, version,
@ -267,6 +266,7 @@ class ASGIApp:
message = await self.transport.receive() message = await self.transport.receive()
chunk = message.get("body", b"") chunk = message.get("body", b"")
await self.request.stream.put(chunk) await self.request.stream.put(chunk)
# self.sanic_app.loop.create_task(self.request.stream.put(chunk))
more_body = message.get("more_body", False) more_body = message.get("more_body", False)
@ -294,6 +294,7 @@ class ASGIApp:
headers = [ headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1")) (str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items() for name, value in response.headers.items()
# if name not in ("Set-Cookie",)
] ]
except AttributeError: except AttributeError:
logger.error( logger.error(

View File

@ -5,8 +5,11 @@ from collections import deque
import pytest import pytest
import uvicorn import uvicorn
from sanic import Sanic
from sanic.asgi import MockTransport from sanic.asgi import MockTransport
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
from sanic.request import Request
from sanic.response import text
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
@ -201,3 +204,28 @@ def test_improper_websocket_connection(transport, send, receive):
transport.create_websocket_connection(send, receive) transport.create_websocket_connection(send, receive)
connection = transport.get_websocket_connection() connection = transport.get_websocket_connection()
assert isinstance(connection, WebSocketConnection) assert isinstance(connection, WebSocketConnection)
@pytest.mark.asyncio
async def test_request_class_regular(app):
@app.get("/regular")
def regular_request(request):
return text(request.__class__.__name__)
_, response = await app.asgi_client.get("/regular")
assert response.body == b"Request"
@pytest.mark.asyncio
async def test_request_class_custom():
class MyCustomRequest(Request):
pass
app = Sanic(request_class=MyCustomRequest)
@app.get("/custom")
def custom_request(request):
return text(request.__class__.__name__)
_, response = await app.asgi_client.get("/custom")
assert response.body == b"MyCustomRequest"

View File

@ -19,8 +19,8 @@ def temp_path():
class ConfigTest: class ConfigTest:
not_for_config = 'should not be used' not_for_config = "should not be used"
CONFIG_VALUE = 'should be used' CONFIG_VALUE = "should be used"
def test_load_from_object(app): def test_load_from_object(app):
@ -31,15 +31,15 @@ def test_load_from_object(app):
def test_load_from_object_string(app): def test_load_from_object_string(app):
app.config.from_object('test_config.ConfigTest') app.config.from_object("test_config.ConfigTest")
assert 'CONFIG_VALUE' in app.config assert "CONFIG_VALUE" in app.config
assert app.config.CONFIG_VALUE == 'should be used' assert app.config.CONFIG_VALUE == "should be used"
assert 'not_for_config' not in app.config assert "not_for_config" not in app.config
def test_load_from_object_string_exception(app): def test_load_from_object_string_exception(app):
with pytest.raises(ImportError): with pytest.raises(ImportError):
app.config.from_object('test_config.Config.test') app.config.from_object("test_config.Config.test")
def test_auto_load_env(): def test_auto_load_env():

View File

@ -1,8 +1,9 @@
import inspect import inspect
import pytest
from sanic import helpers from sanic import helpers
from sanic.config import Config from sanic.config import Config
import pytest
def test_has_message_body(): def test_has_message_body():
@ -63,15 +64,15 @@ def test_remove_entity_headers():
def test_import_string_class(): def test_import_string_class():
obj = helpers.import_string('sanic.config.Config') obj = helpers.import_string("sanic.config.Config")
assert isinstance(obj, Config) assert isinstance(obj, Config)
def test_import_string_module(): def test_import_string_module():
module = helpers.import_string('sanic.config') module = helpers.import_string("sanic.config")
assert inspect.ismodule(module) assert inspect.ismodule(module)
def test_import_string_exception(): def test_import_string_exception():
with pytest.raises(ImportError): with pytest.raises(ImportError):
helpers.import_string('test.test.test') helpers.import_string("test.test.test")

View File

@ -54,4 +54,4 @@ deps =
bandit bandit
commands = commands =
bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py