This commit is contained in:
Adam Hopkins 2020-07-14 10:25:56 +03:00
parent eddb5bad91
commit 16d36fc17f
4 changed files with 70 additions and 21 deletions

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import warnings import warnings
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import (
Any, Any,
@ -15,6 +16,7 @@ from typing import (
from urllib.parse import quote from urllib.parse import quote
import sanic.app # noqa import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger from sanic.log import logger
@ -23,6 +25,7 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import ConnInfo, StreamBuffer from sanic.server import ConnInfo, StreamBuffer
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any] ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any]
ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGISend = Callable[[ASGIMessage], Awaitable[None]]
@ -65,7 +68,9 @@ class MockProtocol:
class MockTransport: class MockTransport:
_protocol: Optional[MockProtocol] _protocol: Optional[MockProtocol]
def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None: def __init__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.scope = scope self.scope = scope
self._receive = receive self._receive = receive
self._send = send self._send = send
@ -141,7 +146,9 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) ) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners: for handler in listeners:
response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop) response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response): if isawaitable(response):
await response await response
@ -159,7 +166,9 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners: for handler in listeners:
response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop) response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response): if isawaitable(response):
await response await response
@ -204,13 +213,19 @@ class ASGIApp:
for key, value in scope.get("headers", []) for key, value in scope.get("headers", [])
] ]
) )
instance.do_stream = True if headers.get("expect") == "100-continue" else False instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance) instance.lifespan = Lifespan(instance)
if scope["type"] == "lifespan": if scope["type"] == "lifespan":
await instance.lifespan(scope, receive, send) await instance.lifespan(scope, receive, send)
else: else:
path = scope["path"][1:] if scope["path"].startswith("/") else scope["path"] path = (
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)]) url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1") url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"] url_bytes += b"?" + scope["query_string"]
@ -233,12 +248,19 @@ class ASGIApp:
request_class = sanic_app.request_class or Request request_class = sanic_app.request_class or Request
instance.request = request_class( instance.request = request_class(
url_bytes, headers, version, method, instance.transport, sanic_app, url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
) )
instance.request.conn_info = ConnInfo(instance.transport) instance.request.conn_info = ConnInfo(instance.transport)
if sanic_app.is_request_stream: if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler(instance.request) is_stream_handler = sanic_app.router.is_stream_handler(
instance.request
)
if is_stream_handler: if is_stream_handler:
instance.request.stream = StreamBuffer( instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
@ -317,7 +339,9 @@ class ASGIApp:
type(response), type(response),
) )
exception = ServerError("Invalid response type") exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(self.request, exception) response = self.sanic_app.error_handler.response(
self.request, exception
)
headers = [ headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1")) (str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items() for name, value in response.headers.items()
@ -327,10 +351,14 @@ class ASGIApp:
if "content-length" not in response.headers and not isinstance( if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse response, StreamingHTTPResponse
): ):
headers += [(b"content-length", str(len(response.body)).encode("latin-1"))] headers += [
(b"content-length", str(len(response.body)).encode("latin-1"))
]
if "content-type" not in response.headers: if "content-type" not in response.headers:
headers += [(b"content-type", str(response.content_type).encode("latin-1"))] headers += [
(b"content-type", str(response.content_type).encode("latin-1"))
]
if response.cookies: if response.cookies:
cookies.update( cookies.update(
@ -342,7 +370,8 @@ class ASGIApp:
) )
headers += [ headers += [
(b"set-cookie", cookie.encode("utf-8")) for k, cookie in cookies.items() (b"set-cookie", cookie.encode("utf-8"))
for k, cookie in cookies.items()
] ]
await self.transport.send( await self.transport.send(

View File

@ -95,7 +95,9 @@ class SanicTestClient:
@self.app.exception(MethodNotSupported) @self.app.exception(MethodNotSupported)
async def error_handler(request, exception): async def error_handler(request, exception):
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text("", exception.status_code, headers=exception.headers) return text(
"", exception.status_code, headers=exception.headers
)
else: else:
return self.app.error_handler.default(request, exception) return self.app.error_handler.default(request, exception)
@ -111,7 +113,9 @@ class SanicTestClient:
host, port = sock.getsockname() host, port = sock.getsockname()
self.port = port self.port = port
if uri.startswith(("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")): if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
url = uri url = uri
else: else:
uri = uri if uri.startswith("/") else f"/{uri}" uri = uri if uri.startswith("/") else f"/{uri}"
@ -243,7 +247,9 @@ class SanicASGITestClient(httpx.AsyncClient):
headers.setdefault("sec-websocket-key", "testserver==") headers.setdefault("sec-websocket-key", "testserver==")
headers.setdefault("sec-websocket-version", "13") headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None: if subprotocols is not None:
headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) headers.setdefault(
"sec-websocket-protocol", ", ".join(subprotocols)
)
scope = { scope = {
"type": "websocket", "type": "websocket",

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
import sys import sys
from inspect import isawaitable from inspect import isawaitable
import pytest import pytest
@ -29,7 +30,9 @@ def test_app_loop_running(app):
assert response.text == "pass" assert response.text == "pass"
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") @pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_create_asyncio_server(app): def test_create_asyncio_server(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -39,7 +42,9 @@ def test_create_asyncio_server(app):
assert srv.is_serving() is True assert srv.is_serving() is True
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") @pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_no_start_serving(app): def test_asyncio_server_no_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -52,7 +57,9 @@ def test_asyncio_server_no_start_serving(app):
assert srv.is_serving() is False assert srv.is_serving() is False
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") @pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_start_serving(app): def test_asyncio_server_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -149,7 +156,9 @@ def test_handle_request_with_nested_exception(app, monkeypatch):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg) raise Exception(err_msg)
monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response) monkeypatch.setattr(
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -168,7 +177,9 @@ def test_handle_request_with_nested_exception_debug(app, monkeypatch):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg) raise Exception(err_msg)
monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response) monkeypatch.setattr(
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -187,7 +198,9 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise SanicException("Mock SanicException") raise SanicException("Mock SanicException")
monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response) monkeypatch.setattr(
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):

View File

@ -1,10 +1,11 @@
import asyncio import asyncio
import sys import sys
from collections import deque, namedtuple from collections import deque, namedtuple
import pytest import pytest
import uvicorn import uvicorn
from sanic import Sanic from sanic import Sanic
from sanic.asgi import MockTransport from sanic.asgi import MockTransport
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage