This commit is contained in:
Adam Hopkins 2021-01-31 16:08:19 +02:00
parent 4358a7eefd
commit a46ea4fc59
9 changed files with 125 additions and 98 deletions

View File

@ -7,21 +7,20 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
from asyncio.futures import Future
from collections import defaultdict, deque
from functools import partial
from inspect import isawaitable, signature
from inspect import isawaitable
from socket import socket
from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc
from typing import (
Any,
Callable,
Coroutine,
Deque,
Dict,
Iterable,
List,
Optional,
Set,
Type,
TypeVar,
Union,
)
from urllib.parse import urlencode, urlunparse
@ -33,7 +32,6 @@ from sanic.asgi import ASGIApp
from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
from sanic.config import BASE_LOGO, Config
from sanic.constants import HTTP_METHODS
from sanic.exceptions import (
InvalidUsage,
NotFound,
@ -64,9 +62,9 @@ from sanic.server import (
Signal,
serve,
serve_multiple,
trigger_events,
)
from sanic.static import register as static_register
from sanic.views import CompositionView
from sanic.websocket import ConnectionClosed, WebSocketProtocol
@ -111,8 +109,8 @@ class Sanic(
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler()
self.config = Config(load_env=load_env)
self.request_middleware: Iterable[MiddlewareType] = deque()
self.response_middleware: Iterable[MiddlewareType] = deque()
self.request_middleware: Deque[MiddlewareType] = deque()
self.response_middleware: Deque[MiddlewareType] = deque()
self.blueprints: Dict[str, Blueprint] = {}
self._blueprint_order: List[Blueprint] = []
self.configure_logging = configure_logging
@ -124,8 +122,8 @@ class Sanic(
self.is_running = False
self.websocket_enabled = False
self.websocket_tasks: Set[Future] = set()
self.named_request_middleware: Dict[str, MiddlewareType] = {}
self.named_response_middleware: Dict[str, MiddlewareType] = {}
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
self._test_client = None
self._asgi_client = None
# Register alternative method names
@ -164,7 +162,8 @@ class Sanic(
also return a future, and the actual ensure_future call
is delayed until before server start.
`See user guide <https://sanicframework.org/guide/basics/tasks.html#background-tasks>`_
`See user guide
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`_
:param task: future, couroutine or awaitable
"""
@ -272,7 +271,8 @@ class Sanic(
:param middleware: the middleware to execute
:param route_names: a list of the names of the endpoints
:type route_names: Iterable[str]
:param attach_to: whether to attach to request or response, defaults to "request"
:param attach_to: whether to attach to request or response,
defaults to "request"
:type attach_to: str, optional
"""
if attach_to == "request":
@ -336,10 +336,11 @@ class Sanic(
Keyword arguments that are not request parameters will be included in
the output URL's query string.
`See user guide <https://sanicframework.org/guide/basics/routing.html#generating-a-url>`__
`See user guide
<https://sanicframework.org/guide/basics/routing.html#generating-a-url>`_
:param view_name: string referencing the view name
:param \**kwargs: keys and values that are used to build request
:param kwargs: keys and values that are used to build request
parameters and query string arguments.
:return: the built URL
@ -509,10 +510,12 @@ class Sanic(
response = await request.respond(response)
except BaseException:
# Skip response middleware
if request.stream:
request.stream.respond(response)
await response.send(end_stream=True)
raise
else:
if request.stream:
response = request.stream.response
if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True)
@ -551,7 +554,11 @@ class Sanic(
) = self.router.get(request)
request.name = name
if request.stream.request_body and not ignore_body:
if (
request.stream
and request.stream.request_body
and not ignore_body
):
if self.router.is_stream_handler(request):
# Streaming handler: lift the size limit
request.stream.request_max_size = float("inf")
@ -589,6 +596,7 @@ class Sanic(
if response:
response = await request.respond(response)
else:
if request.stream:
response = request.stream.response
# Make sure that response is finished / run StreamingHTTP callback
@ -597,7 +605,7 @@ class Sanic(
else:
try:
# Fastest method for checking if the property exists
handler.is_websocket
handler.is_websocket # type: ignore
except AttributeError:
raise ServerError(
f"Invalid response type {response!r} "
@ -841,7 +849,7 @@ class Sanic(
)
# Trigger before_start events
await self.trigger_events(
await trigger_events(
server_settings.get("before_start", []),
server_settings.get("loop"),
)
@ -850,17 +858,6 @@ class Sanic(
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
)
async def trigger_events(self, events, loop):
"""
Trigger events (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
for event in events:
result = event(loop)
if isawaitable(result):
await result
async def _run_request_middleware(self, request, request_name=None):
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
@ -1075,7 +1072,8 @@ class Sanic(
"""
Update app.config. Full implementation can be found in the user guide.
`See user guide <https://sanicframework.org/guide/deployment/configuration.html#basics>`__
`See user guide
<https://sanicframework.org/guide/deployment/configuration.html#basics>`__
"""
self.config.update_config(config)

View File

@ -1,8 +1,7 @@
from __future__ import annotations
from typing import (
Optional,
TYPE_CHECKING,
)
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from sanic.request import Request
@ -264,7 +263,7 @@ class Http:
# Compatibility with simple response body
if not data and getattr(res, "body", None):
data, end_stream = res.body, True
data, end_stream = res.body, True # type: ignore
size = len(data)
headers = res.headers

View File

@ -24,7 +24,9 @@ class ListenerMixin:
def listener(
self,
listener_or_event: Union[Callable[..., Coroutine[Any, Any, None]]],
listener_or_event: Union[
Callable[..., Coroutine[Any, Any, None]], str
],
event_or_none: Optional[str] = None,
apply: bool = True,
):
@ -39,7 +41,8 @@ class ListenerMixin:
async def before_server_start(app, loop):
...
`See user guide <https://sanicframework.org/guide/basics/listeners.html#listeners>`__
`See user guide
<https://sanicframework.org/guide/basics/listeners.html#listeners>`_
:param event: event to listen to
"""

View File

@ -19,7 +19,8 @@ class MiddlewareMixin:
Can either be called as *@app.middleware* or
*@app.middleware('request')*
`See user guide <https://sanicframework.org/guide/basics/middleware.html>`__
`See user guide
<https://sanicframework.org/guide/basics/middleware.html>`_
:param: middleware_or_request: Optional parameter to use for
identifying which type of middleware is being registered.

View File

@ -100,7 +100,7 @@ class RouteMixin:
route = FutureRoute(
handler,
uri,
methods,
frozenset(methods),
host,
strict_slashes,
stream,

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
import email.utils
import uuid
from asyncio.transports import BaseTransport
from asyncio.transports import Transport
from collections import defaultdict
from http.cookies import SimpleCookie
from types import SimpleNamespace
@ -115,7 +115,7 @@ class Request:
headers: Header,
version: str,
method: str,
transport: BaseTransport,
transport: Transport,
app: Sanic,
):
self.raw_url = url_bytes
@ -144,11 +144,11 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list)
self.uri_template = None
self.uri_template: Optional[str] = None
self.request_middleware_started = False
self._cookies: Dict[str, str] = {}
self.stream: Optional[Http] = None
self.endpoint = None
self.endpoint: Optional[str] = None
def __repr__(self):
class_name = self.__class__.__name__
@ -182,7 +182,7 @@ class Request:
self, response, request_name=self.name
)
# Redefining this as a tuple here satisfies mypy
except tuple(CancelledErrors):
except tuple(*CancelledErrors):
raise
except Exception:
error_logger.exception(

View File

@ -192,7 +192,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self,
streaming_fn: StreamingFunction,
status: int = 200,
headers: Optional[Dict[str, str]] = None,
headers: Optional[Union[Header, Dict[str, str]]] = None,
content_type: str = "text/plain; charset=utf-8",
chunked="deprecated",
):
@ -244,7 +244,7 @@ class HTTPResponse(BaseHTTPResponse):
self,
body: Optional[AnyStr] = None,
status: int = 200,
headers: Optional[Dict[str, str]] = None,
headers: Optional[Union[Header, Dict[str, str]]] = None,
content_type: Optional[str] = None,
):
super().__init__()

View File

@ -1,17 +1,19 @@
from __future__ import annotations
from ssl import SSLContext
from typing import (
TYPE_CHECKING,
DefaultDict,
Any,
Callable,
Dict,
List,
NamedTuple,
Iterable,
Optional,
Tuple,
Type,
Union,
)
from sanic.handlers import ListenerType
if TYPE_CHECKING:
from sanic.app import Sanic
@ -25,7 +27,7 @@ import stat
import sys
from asyncio import CancelledError
from asyncio.transports import BaseTransport
from asyncio.transports import Transport
from functools import partial
from inspect import isawaitable
from ipaddress import ip_address
@ -67,8 +69,8 @@ class ConnInfo:
"ssl",
)
def __init__(self, transport: BaseTransport, unix=None):
self.ssl = bool(transport.get_extra_info("sslcontext"))
def __init__(self, transport: Transport, unix=None):
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
self.server = self.client = ""
self.server_port = self.client_port = 0
self.peername = None
@ -134,8 +136,8 @@ class HttpProtocol(asyncio.Protocol):
self,
*,
loop,
app,
signal=Signal(),
app: Sanic,
signal=None,
connections=None,
state=None,
unix=None,
@ -146,10 +148,10 @@ class HttpProtocol(asyncio.Protocol):
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
self.app: Sanic = app
self.url = None
self.transport = None
self.conn_info = None
self.request = None
self.signal = signal
self.transport: Optional[Transport] = None
self.conn_info: Optional[ConnInfo] = None
self.request: Optional[Request] = None
self.signal = signal or Signal()
self.access_log = self.app.config.ACCESS_LOG
self.connections = connections if connections is not None else set()
self.request_handler = self.app.handle_request
@ -177,7 +179,8 @@ class HttpProtocol(asyncio.Protocol):
self.check_timeouts()
async def connection_task(self):
"""Run a HTTP connection.
"""
Run a HTTP connection.
Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there.
@ -204,13 +207,17 @@ class HttpProtocol(asyncio.Protocol):
logger.exception("Closing failed")
async def receive_more(self):
"""Wait until more data is received into self._buffer."""
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def check_timeouts(self):
"""Runs itself periodically to enforce any expired timeouts."""
"""
Runs itself periodically to enforce any expired timeouts.
"""
try:
if not self._task:
return
@ -241,15 +248,18 @@ class HttpProtocol(asyncio.Protocol):
logger.exception("protocol.check_timeouts")
async def send(self, data):
"""Writes data with backpressure control."""
"""
Writes data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
self.transport.write(data)
self._time = current_time()
def close_if_idle(self):
"""Close the connection if a request is not being sent or received
def close_if_idle(self) -> bool:
"""
Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open
"""
@ -298,14 +308,17 @@ class HttpProtocol(asyncio.Protocol):
def resume_writing(self):
self._can_write.set()
def data_received(self, data):
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
if len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE:
if (
len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE
and self.transport
):
self.transport.pause_reading()
if self._data_received:
@ -314,12 +327,14 @@ class HttpProtocol(asyncio.Protocol):
logger.exception("protocol.data_received")
def trigger_events(events, loop):
"""Trigger event callbacks (functions or async)
def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
"""
Trigger event callbacks (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
if events:
for event in events:
result = event(loop)
if isawaitable(result):
@ -347,9 +362,9 @@ class AsyncioServer:
loop,
serve_coro,
connections,
after_start,
before_stop,
after_stop,
after_start: Optional[Iterable[ListenerType]],
before_stop: Optional[Iterable[ListenerType]],
after_stop: Optional[Iterable[ListenerType]],
):
# Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here.
@ -362,18 +377,24 @@ class AsyncioServer:
self.connections = connections
def after_start(self):
"""Trigger "after_server_start" events"""
"""
Trigger "after_server_start" events
"""
trigger_events(self._after_start, self.loop)
def before_stop(self):
"""Trigger "before_server_stop" events"""
"""
Trigger "before_server_stop" events
"""
trigger_events(self._before_stop, self.loop)
def after_stop(self):
"""Trigger "after_server_stop" events"""
"""
Trigger "after_server_stop" events
"""
trigger_events(self._after_stop, self.loop)
def is_serving(self):
def is_serving(self) -> bool:
if self.server:
return self.server.is_serving()
return False
@ -410,7 +431,9 @@ class AsyncioServer:
)
def __await__(self):
"""Starts the asyncio server, returns AsyncServerCoro"""
"""
Starts the asyncio server, returns AsyncServerCoro
"""
task = asyncio.ensure_future(self.serve_coro)
while not task.done():
yield
@ -422,20 +445,20 @@ def serve(
host,
port,
app,
before_start=None,
after_start=None,
before_stop=None,
after_stop=None,
ssl=None,
sock=None,
unix=None,
reuse_port=False,
before_start: Optional[Iterable[ListenerType]] = None,
after_start: Optional[Iterable[ListenerType]] = None,
before_stop: Optional[Iterable[ListenerType]] = None,
after_stop: Optional[Iterable[ListenerType]] = None,
ssl: Optional[SSLContext] = None,
sock: Optional[socket.socket] = None,
unix: Optional[str] = None,
reuse_port: bool = False,
loop=None,
protocol=HttpProtocol,
backlog=100,
register_sys_signals=True,
run_multiple=False,
run_async=False,
protocol: Type[asyncio.Protocol] = HttpProtocol,
backlog: int = 100,
register_sys_signals: bool = True,
run_multiple: bool = False,
run_async: bool = False,
connections=None,
signal=Signal(),
state=None,
@ -560,7 +583,7 @@ def serve(
# instead of letting connection hangs forever.
# Let's roughly calcucate time.
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
start_shutdown = 0
start_shutdown: float = 0
while connections and (start_shutdown < graceful):
loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1
@ -584,7 +607,7 @@ def serve(
def _build_protocol_kwargs(
protocol: Type[HttpProtocol], config: Config
protocol: Type[asyncio.Protocol], config: Config
) -> Dict[str, Union[int, float]]:
if hasattr(protocol, "websocket_handshake"):
return {
@ -660,7 +683,7 @@ def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
return sock
def remove_unix_socket(path: str) -> None:
def remove_unix_socket(path: Optional[str]) -> None:
"""Remove dead unix socket during server exit."""
if not path:
return

View File

@ -92,6 +92,9 @@ class CompositionView:
self.handlers = {}
self.name = self.__class__.__name__
def __name__(self):
return self.name
def add(self, methods, handler, stream=False):
if stream:
handler.is_stream = stream