This commit is contained in:
Adam Hopkins 2021-01-31 16:08:19 +02:00
parent 4358a7eefd
commit a46ea4fc59
9 changed files with 125 additions and 98 deletions

View File

@ -7,21 +7,20 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
from asyncio.futures import Future from asyncio.futures import Future
from collections import defaultdict, deque from collections import defaultdict, deque
from functools import partial from functools import partial
from inspect import isawaitable, signature from inspect import isawaitable
from socket import socket from socket import socket
from ssl import Purpose, SSLContext, create_default_context from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc from traceback import format_exc
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Coroutine, Deque,
Dict, Dict,
Iterable, Iterable,
List, List,
Optional, Optional,
Set, Set,
Type, Type,
TypeVar,
Union, Union,
) )
from urllib.parse import urlencode, urlunparse from urllib.parse import urlencode, urlunparse
@ -33,7 +32,6 @@ from sanic.asgi import ASGIApp
from sanic.blueprint_group import BlueprintGroup from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.config import BASE_LOGO, Config from sanic.config import BASE_LOGO, Config
from sanic.constants import HTTP_METHODS
from sanic.exceptions import ( from sanic.exceptions import (
InvalidUsage, InvalidUsage,
NotFound, NotFound,
@ -64,9 +62,9 @@ from sanic.server import (
Signal, Signal,
serve, serve,
serve_multiple, serve_multiple,
trigger_events,
) )
from sanic.static import register as static_register from sanic.static import register as static_register
from sanic.views import CompositionView
from sanic.websocket import ConnectionClosed, WebSocketProtocol from sanic.websocket import ConnectionClosed, WebSocketProtocol
@ -111,8 +109,8 @@ class Sanic(
self.request_class = request_class self.request_class = request_class
self.error_handler = error_handler or ErrorHandler() self.error_handler = error_handler or ErrorHandler()
self.config = Config(load_env=load_env) self.config = Config(load_env=load_env)
self.request_middleware: Iterable[MiddlewareType] = deque() self.request_middleware: Deque[MiddlewareType] = deque()
self.response_middleware: Iterable[MiddlewareType] = deque() self.response_middleware: Deque[MiddlewareType] = deque()
self.blueprints: Dict[str, Blueprint] = {} self.blueprints: Dict[str, Blueprint] = {}
self._blueprint_order: List[Blueprint] = [] self._blueprint_order: List[Blueprint] = []
self.configure_logging = configure_logging self.configure_logging = configure_logging
@ -124,8 +122,8 @@ class Sanic(
self.is_running = False self.is_running = False
self.websocket_enabled = False self.websocket_enabled = False
self.websocket_tasks: Set[Future] = set() self.websocket_tasks: Set[Future] = set()
self.named_request_middleware: Dict[str, MiddlewareType] = {} self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, MiddlewareType] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
self._test_client = None self._test_client = None
self._asgi_client = None self._asgi_client = None
# Register alternative method names # Register alternative method names
@ -164,7 +162,8 @@ class Sanic(
also return a future, and the actual ensure_future call also return a future, and the actual ensure_future call
is delayed until before server start. is delayed until before server start.
`See user guide <https://sanicframework.org/guide/basics/tasks.html#background-tasks>`_ `See user guide
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`_
:param task: future, couroutine or awaitable :param task: future, couroutine or awaitable
""" """
@ -272,7 +271,8 @@ class Sanic(
:param middleware: the middleware to execute :param middleware: the middleware to execute
:param route_names: a list of the names of the endpoints :param route_names: a list of the names of the endpoints
:type route_names: Iterable[str] :type route_names: Iterable[str]
:param attach_to: whether to attach to request or response, defaults to "request" :param attach_to: whether to attach to request or response,
defaults to "request"
:type attach_to: str, optional :type attach_to: str, optional
""" """
if attach_to == "request": if attach_to == "request":
@ -336,10 +336,11 @@ class Sanic(
Keyword arguments that are not request parameters will be included in Keyword arguments that are not request parameters will be included in
the output URL's query string. the output URL's query string.
`See user guide <https://sanicframework.org/guide/basics/routing.html#generating-a-url>`__ `See user guide
<https://sanicframework.org/guide/basics/routing.html#generating-a-url>`_
:param view_name: string referencing the view name :param view_name: string referencing the view name
:param \**kwargs: keys and values that are used to build request :param kwargs: keys and values that are used to build request
parameters and query string arguments. parameters and query string arguments.
:return: the built URL :return: the built URL
@ -509,11 +510,13 @@ class Sanic(
response = await request.respond(response) response = await request.respond(response)
except BaseException: except BaseException:
# Skip response middleware # Skip response middleware
request.stream.respond(response) if request.stream:
request.stream.respond(response)
await response.send(end_stream=True) await response.send(end_stream=True)
raise raise
else: else:
response = request.stream.response if request.stream:
response = request.stream.response
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True) await response.send(end_stream=True)
else: else:
@ -551,7 +554,11 @@ class Sanic(
) = self.router.get(request) ) = self.router.get(request)
request.name = name request.name = name
if request.stream.request_body and not ignore_body: if (
request.stream
and request.stream.request_body
and not ignore_body
):
if self.router.is_stream_handler(request): if self.router.is_stream_handler(request):
# Streaming handler: lift the size limit # Streaming handler: lift the size limit
request.stream.request_max_size = float("inf") request.stream.request_max_size = float("inf")
@ -589,7 +596,8 @@ class Sanic(
if response: if response:
response = await request.respond(response) response = await request.respond(response)
else: else:
response = request.stream.response if request.stream:
response = request.stream.response
# Make sure that response is finished / run StreamingHTTP callback # Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
@ -597,7 +605,7 @@ class Sanic(
else: else:
try: try:
# Fastest method for checking if the property exists # Fastest method for checking if the property exists
handler.is_websocket handler.is_websocket # type: ignore
except AttributeError: except AttributeError:
raise ServerError( raise ServerError(
f"Invalid response type {response!r} " f"Invalid response type {response!r} "
@ -841,7 +849,7 @@ class Sanic(
) )
# Trigger before_start events # Trigger before_start events
await self.trigger_events( await trigger_events(
server_settings.get("before_start", []), server_settings.get("before_start", []),
server_settings.get("loop"), server_settings.get("loop"),
) )
@ -850,17 +858,6 @@ class Sanic(
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
) )
async def trigger_events(self, events, loop):
"""
Trigger events (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
for event in events:
result = event(loop)
if isawaitable(result):
await result
async def _run_request_middleware(self, request, request_name=None): async def _run_request_middleware(self, request, request_name=None):
# The if improves speed. I don't know why # The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get( named_middleware = self.named_request_middleware.get(
@ -1075,7 +1072,8 @@ class Sanic(
""" """
Update app.config. Full implementation can be found in the user guide. Update app.config. Full implementation can be found in the user guide.
`See user guide <https://sanicframework.org/guide/deployment/configuration.html#basics>`__ `See user guide
<https://sanicframework.org/guide/deployment/configuration.html#basics>`__
""" """
self.config.update_config(config) self.config.update_config(config)

View File

@ -1,8 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import (
Optional, from typing import TYPE_CHECKING, Optional
TYPE_CHECKING,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.request import Request from sanic.request import Request
@ -264,7 +263,7 @@ class Http:
# Compatibility with simple response body # Compatibility with simple response body
if not data and getattr(res, "body", None): if not data and getattr(res, "body", None):
data, end_stream = res.body, True data, end_stream = res.body, True # type: ignore
size = len(data) size = len(data)
headers = res.headers headers = res.headers

View File

@ -24,7 +24,9 @@ class ListenerMixin:
def listener( def listener(
self, self,
listener_or_event: Union[Callable[..., Coroutine[Any, Any, None]]], listener_or_event: Union[
Callable[..., Coroutine[Any, Any, None]], str
],
event_or_none: Optional[str] = None, event_or_none: Optional[str] = None,
apply: bool = True, apply: bool = True,
): ):
@ -39,7 +41,8 @@ class ListenerMixin:
async def before_server_start(app, loop): async def before_server_start(app, loop):
... ...
`See user guide <https://sanicframework.org/guide/basics/listeners.html#listeners>`__ `See user guide
<https://sanicframework.org/guide/basics/listeners.html#listeners>`_
:param event: event to listen to :param event: event to listen to
""" """

View File

@ -19,7 +19,8 @@ class MiddlewareMixin:
Can either be called as *@app.middleware* or Can either be called as *@app.middleware* or
*@app.middleware('request')* *@app.middleware('request')*
`See user guide <https://sanicframework.org/guide/basics/middleware.html>`__ `See user guide
<https://sanicframework.org/guide/basics/middleware.html>`_
:param: middleware_or_request: Optional parameter to use for :param: middleware_or_request: Optional parameter to use for
identifying which type of middleware is being registered. identifying which type of middleware is being registered.

View File

@ -100,7 +100,7 @@ class RouteMixin:
route = FutureRoute( route = FutureRoute(
handler, handler,
uri, uri,
methods, frozenset(methods),
host, host,
strict_slashes, strict_slashes,
stream, stream,

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
import email.utils import email.utils
import uuid import uuid
from asyncio.transports import BaseTransport from asyncio.transports import Transport
from collections import defaultdict from collections import defaultdict
from http.cookies import SimpleCookie from http.cookies import SimpleCookie
from types import SimpleNamespace from types import SimpleNamespace
@ -115,7 +115,7 @@ class Request:
headers: Header, headers: Header,
version: str, version: str,
method: str, method: str,
transport: BaseTransport, transport: Transport,
app: Sanic, app: Sanic,
): ):
self.raw_url = url_bytes self.raw_url = url_bytes
@ -144,11 +144,11 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[ self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]] Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list) ] = defaultdict(list)
self.uri_template = None self.uri_template: Optional[str] = None
self.request_middleware_started = False self.request_middleware_started = False
self._cookies: Dict[str, str] = {} self._cookies: Dict[str, str] = {}
self.stream: Optional[Http] = None self.stream: Optional[Http] = None
self.endpoint = None self.endpoint: Optional[str] = None
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -182,7 +182,7 @@ class Request:
self, response, request_name=self.name self, response, request_name=self.name
) )
# Redefining this as a tuple here satisfies mypy # Redefining this as a tuple here satisfies mypy
except tuple(CancelledErrors): except tuple(*CancelledErrors):
raise raise
except Exception: except Exception:
error_logger.exception( error_logger.exception(

View File

@ -192,7 +192,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self, self,
streaming_fn: StreamingFunction, streaming_fn: StreamingFunction,
status: int = 200, status: int = 200,
headers: Optional[Dict[str, str]] = None, headers: Optional[Union[Header, Dict[str, str]]] = None,
content_type: str = "text/plain; charset=utf-8", content_type: str = "text/plain; charset=utf-8",
chunked="deprecated", chunked="deprecated",
): ):
@ -244,7 +244,7 @@ class HTTPResponse(BaseHTTPResponse):
self, self,
body: Optional[AnyStr] = None, body: Optional[AnyStr] = None,
status: int = 200, status: int = 200,
headers: Optional[Dict[str, str]] = None, headers: Optional[Union[Header, Dict[str, str]]] = None,
content_type: Optional[str] = None, content_type: Optional[str] = None,
): ):
super().__init__() super().__init__()

View File

@ -1,17 +1,19 @@
from __future__ import annotations from __future__ import annotations
from ssl import SSLContext
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
DefaultDict, Any,
Callable,
Dict, Dict,
List, Iterable,
NamedTuple,
Optional, Optional,
Tuple,
Type, Type,
Union, Union,
) )
from sanic.handlers import ListenerType
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.app import Sanic from sanic.app import Sanic
@ -25,7 +27,7 @@ import stat
import sys import sys
from asyncio import CancelledError from asyncio import CancelledError
from asyncio.transports import BaseTransport from asyncio.transports import Transport
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from ipaddress import ip_address from ipaddress import ip_address
@ -67,8 +69,8 @@ class ConnInfo:
"ssl", "ssl",
) )
def __init__(self, transport: BaseTransport, unix=None): def __init__(self, transport: Transport, unix=None):
self.ssl = bool(transport.get_extra_info("sslcontext")) self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
self.server = self.client = "" self.server = self.client = ""
self.server_port = self.client_port = 0 self.server_port = self.client_port = 0
self.peername = None self.peername = None
@ -134,8 +136,8 @@ class HttpProtocol(asyncio.Protocol):
self, self,
*, *,
loop, loop,
app, app: Sanic,
signal=Signal(), signal=None,
connections=None, connections=None,
state=None, state=None,
unix=None, unix=None,
@ -146,10 +148,10 @@ class HttpProtocol(asyncio.Protocol):
deprecated_loop = self.loop if sys.version_info < (3, 7) else None deprecated_loop = self.loop if sys.version_info < (3, 7) else None
self.app: Sanic = app self.app: Sanic = app
self.url = None self.url = None
self.transport = None self.transport: Optional[Transport] = None
self.conn_info = None self.conn_info: Optional[ConnInfo] = None
self.request = None self.request: Optional[Request] = None
self.signal = signal self.signal = signal or Signal()
self.access_log = self.app.config.ACCESS_LOG self.access_log = self.app.config.ACCESS_LOG
self.connections = connections if connections is not None else set() self.connections = connections if connections is not None else set()
self.request_handler = self.app.handle_request self.request_handler = self.app.handle_request
@ -177,7 +179,8 @@ class HttpProtocol(asyncio.Protocol):
self.check_timeouts() self.check_timeouts()
async def connection_task(self): async def connection_task(self):
"""Run a HTTP connection. """
Run a HTTP connection.
Timeouts and some additional error handling occur here, while most of Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there. everything else happens in class Http or in code called from there.
@ -204,13 +207,17 @@ class HttpProtocol(asyncio.Protocol):
logger.exception("Closing failed") logger.exception("Closing failed")
async def receive_more(self): async def receive_more(self):
"""Wait until more data is received into self._buffer.""" """
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading() self.transport.resume_reading()
self._data_received.clear() self._data_received.clear()
await self._data_received.wait() await self._data_received.wait()
def check_timeouts(self): def check_timeouts(self):
"""Runs itself periodically to enforce any expired timeouts.""" """
Runs itself periodically to enforce any expired timeouts.
"""
try: try:
if not self._task: if not self._task:
return return
@ -241,15 +248,18 @@ class HttpProtocol(asyncio.Protocol):
logger.exception("protocol.check_timeouts") logger.exception("protocol.check_timeouts")
async def send(self, data): async def send(self, data):
"""Writes data with backpressure control.""" """
Writes data with backpressure control.
"""
await self._can_write.wait() await self._can_write.wait()
if self.transport.is_closing(): if self.transport.is_closing():
raise CancelledError raise CancelledError
self.transport.write(data) self.transport.write(data)
self._time = current_time() self._time = current_time()
def close_if_idle(self): def close_if_idle(self) -> bool:
"""Close the connection if a request is not being sent or received """
Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open :return: boolean - True if closed, false if staying open
""" """
@ -298,14 +308,17 @@ class HttpProtocol(asyncio.Protocol):
def resume_writing(self): def resume_writing(self):
self._can_write.set() self._can_write.set()
def data_received(self, data): def data_received(self, data: bytes):
try: try:
self._time = current_time() self._time = current_time()
if not data: if not data:
return self.close() return self.close()
self.recv_buffer += data self.recv_buffer += data
if len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE: if (
len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE
and self.transport
):
self.transport.pause_reading() self.transport.pause_reading()
if self._data_received: if self._data_received:
@ -314,16 +327,18 @@ class HttpProtocol(asyncio.Protocol):
logger.exception("protocol.data_received") logger.exception("protocol.data_received")
def trigger_events(events, loop): def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
"""Trigger event callbacks (functions or async) """
Trigger event callbacks (functions or async)
:param events: one or more sync or async functions to execute :param events: one or more sync or async functions to execute
:param loop: event loop :param loop: event loop
""" """
for event in events: if events:
result = event(loop) for event in events:
if isawaitable(result): result = event(loop)
loop.run_until_complete(result) if isawaitable(result):
loop.run_until_complete(result)
class AsyncioServer: class AsyncioServer:
@ -347,9 +362,9 @@ class AsyncioServer:
loop, loop,
serve_coro, serve_coro,
connections, connections,
after_start, after_start: Optional[Iterable[ListenerType]],
before_stop, before_stop: Optional[Iterable[ListenerType]],
after_stop, after_stop: Optional[Iterable[ListenerType]],
): ):
# Note, Sanic already called "before_server_start" events # Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here. # before this helper was even created. So we don't need it here.
@ -362,18 +377,24 @@ class AsyncioServer:
self.connections = connections self.connections = connections
def after_start(self): def after_start(self):
"""Trigger "after_server_start" events""" """
Trigger "after_server_start" events
"""
trigger_events(self._after_start, self.loop) trigger_events(self._after_start, self.loop)
def before_stop(self): def before_stop(self):
"""Trigger "before_server_stop" events""" """
Trigger "before_server_stop" events
"""
trigger_events(self._before_stop, self.loop) trigger_events(self._before_stop, self.loop)
def after_stop(self): def after_stop(self):
"""Trigger "after_server_stop" events""" """
Trigger "after_server_stop" events
"""
trigger_events(self._after_stop, self.loop) trigger_events(self._after_stop, self.loop)
def is_serving(self): def is_serving(self) -> bool:
if self.server: if self.server:
return self.server.is_serving() return self.server.is_serving()
return False return False
@ -410,7 +431,9 @@ class AsyncioServer:
) )
def __await__(self): def __await__(self):
"""Starts the asyncio server, returns AsyncServerCoro""" """
Starts the asyncio server, returns AsyncServerCoro
"""
task = asyncio.ensure_future(self.serve_coro) task = asyncio.ensure_future(self.serve_coro)
while not task.done(): while not task.done():
yield yield
@ -422,20 +445,20 @@ def serve(
host, host,
port, port,
app, app,
before_start=None, before_start: Optional[Iterable[ListenerType]] = None,
after_start=None, after_start: Optional[Iterable[ListenerType]] = None,
before_stop=None, before_stop: Optional[Iterable[ListenerType]] = None,
after_stop=None, after_stop: Optional[Iterable[ListenerType]] = None,
ssl=None, ssl: Optional[SSLContext] = None,
sock=None, sock: Optional[socket.socket] = None,
unix=None, unix: Optional[str] = None,
reuse_port=False, reuse_port: bool = False,
loop=None, loop=None,
protocol=HttpProtocol, protocol: Type[asyncio.Protocol] = HttpProtocol,
backlog=100, backlog: int = 100,
register_sys_signals=True, register_sys_signals: bool = True,
run_multiple=False, run_multiple: bool = False,
run_async=False, run_async: bool = False,
connections=None, connections=None,
signal=Signal(), signal=Signal(),
state=None, state=None,
@ -560,7 +583,7 @@ def serve(
# instead of letting connection hangs forever. # instead of letting connection hangs forever.
# Let's roughly calcucate time. # Let's roughly calcucate time.
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
start_shutdown = 0 start_shutdown: float = 0
while connections and (start_shutdown < graceful): while connections and (start_shutdown < graceful):
loop.run_until_complete(asyncio.sleep(0.1)) loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1 start_shutdown = start_shutdown + 0.1
@ -584,7 +607,7 @@ def serve(
def _build_protocol_kwargs( def _build_protocol_kwargs(
protocol: Type[HttpProtocol], config: Config protocol: Type[asyncio.Protocol], config: Config
) -> Dict[str, Union[int, float]]: ) -> Dict[str, Union[int, float]]:
if hasattr(protocol, "websocket_handshake"): if hasattr(protocol, "websocket_handshake"):
return { return {
@ -660,7 +683,7 @@ def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
return sock return sock
def remove_unix_socket(path: str) -> None: def remove_unix_socket(path: Optional[str]) -> None:
"""Remove dead unix socket during server exit.""" """Remove dead unix socket during server exit."""
if not path: if not path:
return return

View File

@ -92,6 +92,9 @@ class CompositionView:
self.handlers = {} self.handlers = {}
self.name = self.__class__.__name__ self.name = self.__class__.__name__
def __name__(self):
return self.name
def add(self, methods, handler, stream=False): def add(self, methods, handler, stream=False):
if stream: if stream:
handler.is_stream = stream handler.is_stream = stream