Add custom request support to ASGI mode; fix a couple tests
Undo change to request stream test
This commit is contained in:
parent
966b05b47e
commit
a2666a2b8a
|
@ -88,7 +88,7 @@ class MockTransport:
|
||||||
self._websocket_connection = WebSocketConnection(send, receive)
|
self._websocket_connection = WebSocketConnection(send, receive)
|
||||||
return self._websocket_connection
|
return self._websocket_connection
|
||||||
|
|
||||||
def add_task(self) -> None: # noqa
|
def add_task(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def send(self, data) -> None:
|
async def send(self, data) -> None:
|
||||||
|
@ -119,27 +119,20 @@ class Lifespan:
|
||||||
"the ASGI server is stopped."
|
"the ASGI server is stopped."
|
||||||
)
|
)
|
||||||
|
|
||||||
# async def pre_startup(self) -> None:
|
|
||||||
# for handler in self.asgi_app.sanic_app.listeners[
|
|
||||||
# "before_server_start"
|
|
||||||
# ]:
|
|
||||||
# response = handler(
|
|
||||||
# self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
|
||||||
# )
|
|
||||||
# if isawaitable(response):
|
|
||||||
# await response
|
|
||||||
|
|
||||||
async def startup(self) -> None:
|
async def startup(self) -> None:
|
||||||
for handler in self.asgi_app.sanic_app.listeners[
|
"""
|
||||||
"before_server_start"
|
Gather the listeners to fire on server start.
|
||||||
]:
|
Because we are using a third-party server and not Sanic server, we do
|
||||||
response = handler(
|
not have access to fire anything BEFORE the server starts.
|
||||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
Therefore, we fire before_server_start and after_server_start
|
||||||
)
|
in sequence since the ASGI lifespan protocol only supports a single
|
||||||
if isawaitable(response):
|
startup event.
|
||||||
await response
|
"""
|
||||||
|
listeners = self.asgi_app.sanic_app.listeners.get(
|
||||||
|
"before_server_start", []
|
||||||
|
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
|
||||||
|
|
||||||
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
|
for handler in listeners:
|
||||||
response = handler(
|
response = handler(
|
||||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||||
)
|
)
|
||||||
|
@ -147,14 +140,19 @@ class Lifespan:
|
||||||
await response
|
await response
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]:
|
"""
|
||||||
response = handler(
|
Gather the listeners to fire on server stop.
|
||||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
Because we are using a third-party server and not Sanic server, we do
|
||||||
)
|
not have access to fire anything AFTER the server stops.
|
||||||
if isawaitable(response):
|
Therefore, we fire before_server_stop and after_server_stop
|
||||||
await response
|
in sequence since the ASGI lifespan protocol only supports a single
|
||||||
|
shutdown event.
|
||||||
|
"""
|
||||||
|
listeners = self.asgi_app.sanic_app.listeners.get(
|
||||||
|
"before_server_stop", []
|
||||||
|
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
|
||||||
|
|
||||||
for handler in self.asgi_app.sanic_app.listeners["after_server_stop"]:
|
for handler in listeners:
|
||||||
response = handler(
|
response = handler(
|
||||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||||
)
|
)
|
||||||
|
@ -223,7 +221,8 @@ class ASGIApp:
|
||||||
# TODO:
|
# TODO:
|
||||||
# - close connection
|
# - close connection
|
||||||
|
|
||||||
instance.request = Request(
|
request_class = sanic_app.request_class or Request
|
||||||
|
instance.request = request_class(
|
||||||
url_bytes,
|
url_bytes,
|
||||||
headers,
|
headers,
|
||||||
version,
|
version,
|
||||||
|
@ -267,6 +266,7 @@ class ASGIApp:
|
||||||
message = await self.transport.receive()
|
message = await self.transport.receive()
|
||||||
chunk = message.get("body", b"")
|
chunk = message.get("body", b"")
|
||||||
await self.request.stream.put(chunk)
|
await self.request.stream.put(chunk)
|
||||||
|
# self.sanic_app.loop.create_task(self.request.stream.put(chunk))
|
||||||
|
|
||||||
more_body = message.get("more_body", False)
|
more_body = message.get("more_body", False)
|
||||||
|
|
||||||
|
@ -294,6 +294,7 @@ class ASGIApp:
|
||||||
headers = [
|
headers = [
|
||||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||||
for name, value in response.headers.items()
|
for name, value in response.headers.items()
|
||||||
|
# if name not in ("Set-Cookie",)
|
||||||
]
|
]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
|
@ -5,8 +5,11 @@ from collections import deque
|
||||||
import pytest
|
import pytest
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
from sanic import Sanic
|
||||||
from sanic.asgi import MockTransport
|
from sanic.asgi import MockTransport
|
||||||
from sanic.exceptions import InvalidUsage
|
from sanic.exceptions import InvalidUsage
|
||||||
|
from sanic.request import Request
|
||||||
|
from sanic.response import text
|
||||||
from sanic.websocket import WebSocketConnection
|
from sanic.websocket import WebSocketConnection
|
||||||
|
|
||||||
|
|
||||||
|
@ -201,3 +204,28 @@ def test_improper_websocket_connection(transport, send, receive):
|
||||||
transport.create_websocket_connection(send, receive)
|
transport.create_websocket_connection(send, receive)
|
||||||
connection = transport.get_websocket_connection()
|
connection = transport.get_websocket_connection()
|
||||||
assert isinstance(connection, WebSocketConnection)
|
assert isinstance(connection, WebSocketConnection)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_class_regular(app):
|
||||||
|
@app.get("/regular")
|
||||||
|
def regular_request(request):
|
||||||
|
return text(request.__class__.__name__)
|
||||||
|
|
||||||
|
_, response = await app.asgi_client.get("/regular")
|
||||||
|
assert response.body == b"Request"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_class_custom():
|
||||||
|
class MyCustomRequest(Request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
app = Sanic(request_class=MyCustomRequest)
|
||||||
|
|
||||||
|
@app.get("/custom")
|
||||||
|
def custom_request(request):
|
||||||
|
return text(request.__class__.__name__)
|
||||||
|
|
||||||
|
_, response = await app.asgi_client.get("/custom")
|
||||||
|
assert response.body == b"MyCustomRequest"
|
||||||
|
|
|
@ -19,8 +19,8 @@ def temp_path():
|
||||||
|
|
||||||
|
|
||||||
class ConfigTest:
|
class ConfigTest:
|
||||||
not_for_config = 'should not be used'
|
not_for_config = "should not be used"
|
||||||
CONFIG_VALUE = 'should be used'
|
CONFIG_VALUE = "should be used"
|
||||||
|
|
||||||
|
|
||||||
def test_load_from_object(app):
|
def test_load_from_object(app):
|
||||||
|
@ -31,15 +31,15 @@ def test_load_from_object(app):
|
||||||
|
|
||||||
|
|
||||||
def test_load_from_object_string(app):
|
def test_load_from_object_string(app):
|
||||||
app.config.from_object('test_config.ConfigTest')
|
app.config.from_object("test_config.ConfigTest")
|
||||||
assert 'CONFIG_VALUE' in app.config
|
assert "CONFIG_VALUE" in app.config
|
||||||
assert app.config.CONFIG_VALUE == 'should be used'
|
assert app.config.CONFIG_VALUE == "should be used"
|
||||||
assert 'not_for_config' not in app.config
|
assert "not_for_config" not in app.config
|
||||||
|
|
||||||
|
|
||||||
def test_load_from_object_string_exception(app):
|
def test_load_from_object_string_exception(app):
|
||||||
with pytest.raises(ImportError):
|
with pytest.raises(ImportError):
|
||||||
app.config.from_object('test_config.Config.test')
|
app.config.from_object("test_config.Config.test")
|
||||||
|
|
||||||
|
|
||||||
def test_auto_load_env():
|
def test_auto_load_env():
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from sanic import helpers
|
from sanic import helpers
|
||||||
from sanic.config import Config
|
from sanic.config import Config
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_message_body():
|
def test_has_message_body():
|
||||||
|
@ -63,15 +64,15 @@ def test_remove_entity_headers():
|
||||||
|
|
||||||
|
|
||||||
def test_import_string_class():
|
def test_import_string_class():
|
||||||
obj = helpers.import_string('sanic.config.Config')
|
obj = helpers.import_string("sanic.config.Config")
|
||||||
assert isinstance(obj, Config)
|
assert isinstance(obj, Config)
|
||||||
|
|
||||||
|
|
||||||
def test_import_string_module():
|
def test_import_string_module():
|
||||||
module = helpers.import_string('sanic.config')
|
module = helpers.import_string("sanic.config")
|
||||||
assert inspect.ismodule(module)
|
assert inspect.ismodule(module)
|
||||||
|
|
||||||
|
|
||||||
def test_import_string_exception():
|
def test_import_string_exception():
|
||||||
with pytest.raises(ImportError):
|
with pytest.raises(ImportError):
|
||||||
helpers.import_string('test.test.test')
|
helpers.import_string("test.test.test")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user