squash this

This commit is contained in:
Adam Hopkins
2020-07-13 23:59:45 +03:00
parent 050a563e1d
commit cf234fca15
5 changed files with 61 additions and 81 deletions

View File

@@ -1,6 +1,5 @@
import asyncio
import warnings
from inspect import isawaitable
from typing import (
Any,
@@ -16,7 +15,6 @@ from typing import (
from urllib.parse import quote
import sanic.app # noqa
from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
@@ -25,7 +23,6 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import ConnInfo, StreamBuffer
from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any]
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
@@ -68,9 +65,7 @@ class MockProtocol:
class MockTransport:
_protocol: Optional[MockProtocol]
def __init__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None:
self.scope = scope
self._receive = receive
self._send = send
@@ -117,6 +112,8 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
print(self.asgi_app.sanic_app.listeners)
if "before_server_start" in self.asgi_app.sanic_app.listeners:
warnings.warn(
'You have set a listener for "before_server_start" '
@@ -146,9 +143,7 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
if isawaitable(response):
await response
@@ -166,9 +161,7 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
if isawaitable(response):
await response
@@ -213,19 +206,13 @@ class ASGIApp:
for key, value in scope.get("headers", [])
]
)
instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
instance.do_stream = True if headers.get("expect") == "100-continue" else False
instance.lifespan = Lifespan(instance)
if scope["type"] == "lifespan":
await instance.lifespan(scope, receive, send)
else:
path = (
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
path = scope["path"][1:] if scope["path"].startswith("/") else scope["path"]
url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"]
@@ -248,19 +235,12 @@ class ASGIApp:
request_class = sanic_app.request_class or Request
instance.request = request_class(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
url_bytes, headers, version, method, instance.transport, sanic_app,
)
instance.request.conn_info = ConnInfo(instance.transport)
if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler(
instance.request
)
is_stream_handler = sanic_app.router.is_stream_handler(instance.request)
if is_stream_handler:
instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
@@ -339,9 +319,7 @@ class ASGIApp:
type(response),
)
exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(
self.request, exception
)
response = self.sanic_app.error_handler.response(self.request, exception)
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
@@ -351,14 +329,10 @@ class ASGIApp:
if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse
):
headers += [
(b"content-length", str(len(response.body)).encode("latin-1"))
]
headers += [(b"content-length", str(len(response.body)).encode("latin-1"))]
if "content-type" not in response.headers:
headers += [
(b"content-type", str(response.content_type).encode("latin-1"))
]
headers += [(b"content-type", str(response.content_type).encode("latin-1"))]
if response.cookies:
cookies.update(
@@ -370,8 +344,7 @@ class ASGIApp:
)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for k, cookie in cookies.items()
(b"set-cookie", cookie.encode("utf-8")) for k, cookie in cookies.items()
]
await self.transport.send(

View File

@@ -9,7 +9,6 @@ from sanic.exceptions import MethodNotSupported
from sanic.log import logger
from sanic.response import text
ASGI_HOST = "mockserver"
HOST = "127.0.0.1"
PORT = None
@@ -88,9 +87,7 @@ class SanicTestClient:
@self.app.exception(MethodNotSupported)
async def error_handler(request, exception):
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text(
"", exception.status_code, headers=exception.headers
)
return text("", exception.status_code, headers=exception.headers)
else:
return self.app.error_handler.default(request, exception)
@@ -106,9 +103,7 @@ class SanicTestClient:
host, port = sock.getsockname()
self.port = port
if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
if uri.startswith(("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")):
url = uri
else:
uri = uri if uri.startswith("/") else f"/{uri}"
@@ -201,7 +196,6 @@ class SanicASGITestClient(httpx.AsyncClient):
app.asgi = True
self.app = app
self.app.test_mode = True
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0))
super().__init__(dispatch=dispatch, base_url=base_url)
@@ -211,6 +205,17 @@ class SanicASGITestClient(httpx.AsyncClient):
def _collect_request(request):
self.last_request = request
def _start_test_mode(request):
self.app.test_mode = True
@app.listener("after_server_start")
def _end_test_mode(sanic, loop):
sanic.test_mode = True
@app.listener("before_server_end")
def _end_test_mode(sanic, loop):
sanic.test_mode = False
app.request_middleware.appendleft(_collect_request)
async def request(self, method, url, gather_request=True, *args, **kwargs):
@@ -233,9 +238,7 @@ class SanicASGITestClient(httpx.AsyncClient):
headers.setdefault("sec-websocket-key", "testserver==")
headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None:
headers.setdefault(
"sec-websocket-protocol", ", ".join(subprotocols)
)
headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
scope = {
"type": "websocket",