squash
This commit is contained in:
parent
4358a7eefd
commit
a46ea4fc59
60
sanic/app.py
60
sanic/app.py
|
@ -7,21 +7,20 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
|
|||
from asyncio.futures import Future
|
||||
from collections import defaultdict, deque
|
||||
from functools import partial
|
||||
from inspect import isawaitable, signature
|
||||
from inspect import isawaitable
|
||||
from socket import socket
|
||||
from ssl import Purpose, SSLContext, create_default_context
|
||||
from traceback import format_exc
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Deque,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlencode, urlunparse
|
||||
|
@ -33,7 +32,6 @@ from sanic.asgi import ASGIApp
|
|||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.config import BASE_LOGO, Config
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.exceptions import (
|
||||
InvalidUsage,
|
||||
NotFound,
|
||||
|
@ -64,9 +62,9 @@ from sanic.server import (
|
|||
Signal,
|
||||
serve,
|
||||
serve_multiple,
|
||||
trigger_events,
|
||||
)
|
||||
from sanic.static import register as static_register
|
||||
from sanic.views import CompositionView
|
||||
from sanic.websocket import ConnectionClosed, WebSocketProtocol
|
||||
|
||||
|
||||
|
@ -111,8 +109,8 @@ class Sanic(
|
|||
self.request_class = request_class
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
self.config = Config(load_env=load_env)
|
||||
self.request_middleware: Iterable[MiddlewareType] = deque()
|
||||
self.response_middleware: Iterable[MiddlewareType] = deque()
|
||||
self.request_middleware: Deque[MiddlewareType] = deque()
|
||||
self.response_middleware: Deque[MiddlewareType] = deque()
|
||||
self.blueprints: Dict[str, Blueprint] = {}
|
||||
self._blueprint_order: List[Blueprint] = []
|
||||
self.configure_logging = configure_logging
|
||||
|
@ -124,8 +122,8 @@ class Sanic(
|
|||
self.is_running = False
|
||||
self.websocket_enabled = False
|
||||
self.websocket_tasks: Set[Future] = set()
|
||||
self.named_request_middleware: Dict[str, MiddlewareType] = {}
|
||||
self.named_response_middleware: Dict[str, MiddlewareType] = {}
|
||||
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
|
||||
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
|
||||
self._test_client = None
|
||||
self._asgi_client = None
|
||||
# Register alternative method names
|
||||
|
@ -164,7 +162,8 @@ class Sanic(
|
|||
also return a future, and the actual ensure_future call
|
||||
is delayed until before server start.
|
||||
|
||||
`See user guide <https://sanicframework.org/guide/basics/tasks.html#background-tasks>`_
|
||||
`See user guide
|
||||
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`_
|
||||
|
||||
:param task: future, couroutine or awaitable
|
||||
"""
|
||||
|
@ -272,7 +271,8 @@ class Sanic(
|
|||
:param middleware: the middleware to execute
|
||||
:param route_names: a list of the names of the endpoints
|
||||
:type route_names: Iterable[str]
|
||||
:param attach_to: whether to attach to request or response, defaults to "request"
|
||||
:param attach_to: whether to attach to request or response,
|
||||
defaults to "request"
|
||||
:type attach_to: str, optional
|
||||
"""
|
||||
if attach_to == "request":
|
||||
|
@ -336,10 +336,11 @@ class Sanic(
|
|||
Keyword arguments that are not request parameters will be included in
|
||||
the output URL's query string.
|
||||
|
||||
`See user guide <https://sanicframework.org/guide/basics/routing.html#generating-a-url>`__
|
||||
`See user guide
|
||||
<https://sanicframework.org/guide/basics/routing.html#generating-a-url>`_
|
||||
|
||||
:param view_name: string referencing the view name
|
||||
:param \**kwargs: keys and values that are used to build request
|
||||
:param kwargs: keys and values that are used to build request
|
||||
parameters and query string arguments.
|
||||
|
||||
:return: the built URL
|
||||
|
@ -509,11 +510,13 @@ class Sanic(
|
|||
response = await request.respond(response)
|
||||
except BaseException:
|
||||
# Skip response middleware
|
||||
request.stream.respond(response)
|
||||
if request.stream:
|
||||
request.stream.respond(response)
|
||||
await response.send(end_stream=True)
|
||||
raise
|
||||
else:
|
||||
response = request.stream.response
|
||||
if request.stream:
|
||||
response = request.stream.response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
await response.send(end_stream=True)
|
||||
else:
|
||||
|
@ -551,7 +554,11 @@ class Sanic(
|
|||
) = self.router.get(request)
|
||||
request.name = name
|
||||
|
||||
if request.stream.request_body and not ignore_body:
|
||||
if (
|
||||
request.stream
|
||||
and request.stream.request_body
|
||||
and not ignore_body
|
||||
):
|
||||
if self.router.is_stream_handler(request):
|
||||
# Streaming handler: lift the size limit
|
||||
request.stream.request_max_size = float("inf")
|
||||
|
@ -589,7 +596,8 @@ class Sanic(
|
|||
if response:
|
||||
response = await request.respond(response)
|
||||
else:
|
||||
response = request.stream.response
|
||||
if request.stream:
|
||||
response = request.stream.response
|
||||
# Make sure that response is finished / run StreamingHTTP callback
|
||||
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
|
@ -597,7 +605,7 @@ class Sanic(
|
|||
else:
|
||||
try:
|
||||
# Fastest method for checking if the property exists
|
||||
handler.is_websocket
|
||||
handler.is_websocket # type: ignore
|
||||
except AttributeError:
|
||||
raise ServerError(
|
||||
f"Invalid response type {response!r} "
|
||||
|
@ -841,7 +849,7 @@ class Sanic(
|
|||
)
|
||||
|
||||
# Trigger before_start events
|
||||
await self.trigger_events(
|
||||
await trigger_events(
|
||||
server_settings.get("before_start", []),
|
||||
server_settings.get("loop"),
|
||||
)
|
||||
|
@ -850,17 +858,6 @@ class Sanic(
|
|||
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
|
||||
)
|
||||
|
||||
async def trigger_events(self, events, loop):
|
||||
"""
|
||||
Trigger events (functions or async)
|
||||
:param events: one or more sync or async functions to execute
|
||||
:param loop: event loop
|
||||
"""
|
||||
for event in events:
|
||||
result = event(loop)
|
||||
if isawaitable(result):
|
||||
await result
|
||||
|
||||
async def _run_request_middleware(self, request, request_name=None):
|
||||
# The if improves speed. I don't know why
|
||||
named_middleware = self.named_request_middleware.get(
|
||||
|
@ -1075,7 +1072,8 @@ class Sanic(
|
|||
"""
|
||||
Update app.config. Full implementation can be found in the user guide.
|
||||
|
||||
`See user guide <https://sanicframework.org/guide/deployment/configuration.html#basics>`__
|
||||
`See user guide
|
||||
<https://sanicframework.org/guide/deployment/configuration.html#basics>`__
|
||||
"""
|
||||
|
||||
self.config.update_config(config)
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import (
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.request import Request
|
||||
|
@ -264,7 +263,7 @@ class Http:
|
|||
|
||||
# Compatibility with simple response body
|
||||
if not data and getattr(res, "body", None):
|
||||
data, end_stream = res.body, True
|
||||
data, end_stream = res.body, True # type: ignore
|
||||
|
||||
size = len(data)
|
||||
headers = res.headers
|
||||
|
|
|
@ -24,7 +24,9 @@ class ListenerMixin:
|
|||
|
||||
def listener(
|
||||
self,
|
||||
listener_or_event: Union[Callable[..., Coroutine[Any, Any, None]]],
|
||||
listener_or_event: Union[
|
||||
Callable[..., Coroutine[Any, Any, None]], str
|
||||
],
|
||||
event_or_none: Optional[str] = None,
|
||||
apply: bool = True,
|
||||
):
|
||||
|
@ -39,7 +41,8 @@ class ListenerMixin:
|
|||
async def before_server_start(app, loop):
|
||||
...
|
||||
|
||||
`See user guide <https://sanicframework.org/guide/basics/listeners.html#listeners>`__
|
||||
`See user guide
|
||||
<https://sanicframework.org/guide/basics/listeners.html#listeners>`_
|
||||
|
||||
:param event: event to listen to
|
||||
"""
|
||||
|
|
|
@ -19,7 +19,8 @@ class MiddlewareMixin:
|
|||
Can either be called as *@app.middleware* or
|
||||
*@app.middleware('request')*
|
||||
|
||||
`See user guide <https://sanicframework.org/guide/basics/middleware.html>`__
|
||||
`See user guide
|
||||
<https://sanicframework.org/guide/basics/middleware.html>`_
|
||||
|
||||
:param: middleware_or_request: Optional parameter to use for
|
||||
identifying which type of middleware is being registered.
|
||||
|
|
|
@ -100,7 +100,7 @@ class RouteMixin:
|
|||
route = FutureRoute(
|
||||
handler,
|
||||
uri,
|
||||
methods,
|
||||
frozenset(methods),
|
||||
host,
|
||||
strict_slashes,
|
||||
stream,
|
||||
|
|
|
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
|||
import email.utils
|
||||
import uuid
|
||||
|
||||
from asyncio.transports import BaseTransport
|
||||
from asyncio.transports import Transport
|
||||
from collections import defaultdict
|
||||
from http.cookies import SimpleCookie
|
||||
from types import SimpleNamespace
|
||||
|
@ -115,7 +115,7 @@ class Request:
|
|||
headers: Header,
|
||||
version: str,
|
||||
method: str,
|
||||
transport: BaseTransport,
|
||||
transport: Transport,
|
||||
app: Sanic,
|
||||
):
|
||||
self.raw_url = url_bytes
|
||||
|
@ -144,11 +144,11 @@ class Request:
|
|||
self.parsed_not_grouped_args: DefaultDict[
|
||||
Tuple[bool, bool, str, str], List[Tuple[str, str]]
|
||||
] = defaultdict(list)
|
||||
self.uri_template = None
|
||||
self.uri_template: Optional[str] = None
|
||||
self.request_middleware_started = False
|
||||
self._cookies: Dict[str, str] = {}
|
||||
self.stream: Optional[Http] = None
|
||||
self.endpoint = None
|
||||
self.endpoint: Optional[str] = None
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
|
@ -182,7 +182,7 @@ class Request:
|
|||
self, response, request_name=self.name
|
||||
)
|
||||
# Redefining this as a tuple here satisfies mypy
|
||||
except tuple(CancelledErrors):
|
||||
except tuple(*CancelledErrors):
|
||||
raise
|
||||
except Exception:
|
||||
error_logger.exception(
|
||||
|
|
|
@ -192,7 +192,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
|||
self,
|
||||
streaming_fn: StreamingFunction,
|
||||
status: int = 200,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
headers: Optional[Union[Header, Dict[str, str]]] = None,
|
||||
content_type: str = "text/plain; charset=utf-8",
|
||||
chunked="deprecated",
|
||||
):
|
||||
|
@ -244,7 +244,7 @@ class HTTPResponse(BaseHTTPResponse):
|
|||
self,
|
||||
body: Optional[AnyStr] = None,
|
||||
status: int = 200,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
headers: Optional[Union[Header, Dict[str, str]]] = None,
|
||||
content_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
|
125
sanic/server.py
125
sanic/server.py
|
@ -1,17 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ssl import SSLContext
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
DefaultDict,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Iterable,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from sanic.handlers import ListenerType
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.app import Sanic
|
||||
|
@ -25,7 +27,7 @@ import stat
|
|||
import sys
|
||||
|
||||
from asyncio import CancelledError
|
||||
from asyncio.transports import BaseTransport
|
||||
from asyncio.transports import Transport
|
||||
from functools import partial
|
||||
from inspect import isawaitable
|
||||
from ipaddress import ip_address
|
||||
|
@ -67,8 +69,8 @@ class ConnInfo:
|
|||
"ssl",
|
||||
)
|
||||
|
||||
def __init__(self, transport: BaseTransport, unix=None):
|
||||
self.ssl = bool(transport.get_extra_info("sslcontext"))
|
||||
def __init__(self, transport: Transport, unix=None):
|
||||
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
|
||||
self.server = self.client = ""
|
||||
self.server_port = self.client_port = 0
|
||||
self.peername = None
|
||||
|
@ -134,8 +136,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self,
|
||||
*,
|
||||
loop,
|
||||
app,
|
||||
signal=Signal(),
|
||||
app: Sanic,
|
||||
signal=None,
|
||||
connections=None,
|
||||
state=None,
|
||||
unix=None,
|
||||
|
@ -146,10 +148,10 @@ class HttpProtocol(asyncio.Protocol):
|
|||
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
|
||||
self.app: Sanic = app
|
||||
self.url = None
|
||||
self.transport = None
|
||||
self.conn_info = None
|
||||
self.request = None
|
||||
self.signal = signal
|
||||
self.transport: Optional[Transport] = None
|
||||
self.conn_info: Optional[ConnInfo] = None
|
||||
self.request: Optional[Request] = None
|
||||
self.signal = signal or Signal()
|
||||
self.access_log = self.app.config.ACCESS_LOG
|
||||
self.connections = connections if connections is not None else set()
|
||||
self.request_handler = self.app.handle_request
|
||||
|
@ -177,7 +179,8 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self.check_timeouts()
|
||||
|
||||
async def connection_task(self):
|
||||
"""Run a HTTP connection.
|
||||
"""
|
||||
Run a HTTP connection.
|
||||
|
||||
Timeouts and some additional error handling occur here, while most of
|
||||
everything else happens in class Http or in code called from there.
|
||||
|
@ -204,13 +207,17 @@ class HttpProtocol(asyncio.Protocol):
|
|||
logger.exception("Closing failed")
|
||||
|
||||
async def receive_more(self):
|
||||
"""Wait until more data is received into self._buffer."""
|
||||
"""
|
||||
Wait until more data is received into the Server protocol's buffer
|
||||
"""
|
||||
self.transport.resume_reading()
|
||||
self._data_received.clear()
|
||||
await self._data_received.wait()
|
||||
|
||||
def check_timeouts(self):
|
||||
"""Runs itself periodically to enforce any expired timeouts."""
|
||||
"""
|
||||
Runs itself periodically to enforce any expired timeouts.
|
||||
"""
|
||||
try:
|
||||
if not self._task:
|
||||
return
|
||||
|
@ -241,15 +248,18 @@ class HttpProtocol(asyncio.Protocol):
|
|||
logger.exception("protocol.check_timeouts")
|
||||
|
||||
async def send(self, data):
|
||||
"""Writes data with backpressure control."""
|
||||
"""
|
||||
Writes data with backpressure control.
|
||||
"""
|
||||
await self._can_write.wait()
|
||||
if self.transport.is_closing():
|
||||
raise CancelledError
|
||||
self.transport.write(data)
|
||||
self._time = current_time()
|
||||
|
||||
def close_if_idle(self):
|
||||
"""Close the connection if a request is not being sent or received
|
||||
def close_if_idle(self) -> bool:
|
||||
"""
|
||||
Close the connection if a request is not being sent or received
|
||||
|
||||
:return: boolean - True if closed, false if staying open
|
||||
"""
|
||||
|
@ -298,14 +308,17 @@ class HttpProtocol(asyncio.Protocol):
|
|||
def resume_writing(self):
|
||||
self._can_write.set()
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes):
|
||||
try:
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
self.recv_buffer += data
|
||||
|
||||
if len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE:
|
||||
if (
|
||||
len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE
|
||||
and self.transport
|
||||
):
|
||||
self.transport.pause_reading()
|
||||
|
||||
if self._data_received:
|
||||
|
@ -314,16 +327,18 @@ class HttpProtocol(asyncio.Protocol):
|
|||
logger.exception("protocol.data_received")
|
||||
|
||||
|
||||
def trigger_events(events, loop):
|
||||
"""Trigger event callbacks (functions or async)
|
||||
def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
|
||||
"""
|
||||
Trigger event callbacks (functions or async)
|
||||
|
||||
:param events: one or more sync or async functions to execute
|
||||
:param loop: event loop
|
||||
"""
|
||||
for event in events:
|
||||
result = event(loop)
|
||||
if isawaitable(result):
|
||||
loop.run_until_complete(result)
|
||||
if events:
|
||||
for event in events:
|
||||
result = event(loop)
|
||||
if isawaitable(result):
|
||||
loop.run_until_complete(result)
|
||||
|
||||
|
||||
class AsyncioServer:
|
||||
|
@ -347,9 +362,9 @@ class AsyncioServer:
|
|||
loop,
|
||||
serve_coro,
|
||||
connections,
|
||||
after_start,
|
||||
before_stop,
|
||||
after_stop,
|
||||
after_start: Optional[Iterable[ListenerType]],
|
||||
before_stop: Optional[Iterable[ListenerType]],
|
||||
after_stop: Optional[Iterable[ListenerType]],
|
||||
):
|
||||
# Note, Sanic already called "before_server_start" events
|
||||
# before this helper was even created. So we don't need it here.
|
||||
|
@ -362,18 +377,24 @@ class AsyncioServer:
|
|||
self.connections = connections
|
||||
|
||||
def after_start(self):
|
||||
"""Trigger "after_server_start" events"""
|
||||
"""
|
||||
Trigger "after_server_start" events
|
||||
"""
|
||||
trigger_events(self._after_start, self.loop)
|
||||
|
||||
def before_stop(self):
|
||||
"""Trigger "before_server_stop" events"""
|
||||
"""
|
||||
Trigger "before_server_stop" events
|
||||
"""
|
||||
trigger_events(self._before_stop, self.loop)
|
||||
|
||||
def after_stop(self):
|
||||
"""Trigger "after_server_stop" events"""
|
||||
"""
|
||||
Trigger "after_server_stop" events
|
||||
"""
|
||||
trigger_events(self._after_stop, self.loop)
|
||||
|
||||
def is_serving(self):
|
||||
def is_serving(self) -> bool:
|
||||
if self.server:
|
||||
return self.server.is_serving()
|
||||
return False
|
||||
|
@ -410,7 +431,9 @@ class AsyncioServer:
|
|||
)
|
||||
|
||||
def __await__(self):
|
||||
"""Starts the asyncio server, returns AsyncServerCoro"""
|
||||
"""
|
||||
Starts the asyncio server, returns AsyncServerCoro
|
||||
"""
|
||||
task = asyncio.ensure_future(self.serve_coro)
|
||||
while not task.done():
|
||||
yield
|
||||
|
@ -422,20 +445,20 @@ def serve(
|
|||
host,
|
||||
port,
|
||||
app,
|
||||
before_start=None,
|
||||
after_start=None,
|
||||
before_stop=None,
|
||||
after_stop=None,
|
||||
ssl=None,
|
||||
sock=None,
|
||||
unix=None,
|
||||
reuse_port=False,
|
||||
before_start: Optional[Iterable[ListenerType]] = None,
|
||||
after_start: Optional[Iterable[ListenerType]] = None,
|
||||
before_stop: Optional[Iterable[ListenerType]] = None,
|
||||
after_stop: Optional[Iterable[ListenerType]] = None,
|
||||
ssl: Optional[SSLContext] = None,
|
||||
sock: Optional[socket.socket] = None,
|
||||
unix: Optional[str] = None,
|
||||
reuse_port: bool = False,
|
||||
loop=None,
|
||||
protocol=HttpProtocol,
|
||||
backlog=100,
|
||||
register_sys_signals=True,
|
||||
run_multiple=False,
|
||||
run_async=False,
|
||||
protocol: Type[asyncio.Protocol] = HttpProtocol,
|
||||
backlog: int = 100,
|
||||
register_sys_signals: bool = True,
|
||||
run_multiple: bool = False,
|
||||
run_async: bool = False,
|
||||
connections=None,
|
||||
signal=Signal(),
|
||||
state=None,
|
||||
|
@ -560,7 +583,7 @@ def serve(
|
|||
# instead of letting connection hangs forever.
|
||||
# Let's roughly calcucate time.
|
||||
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
start_shutdown = 0
|
||||
start_shutdown: float = 0
|
||||
while connections and (start_shutdown < graceful):
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
start_shutdown = start_shutdown + 0.1
|
||||
|
@ -584,7 +607,7 @@ def serve(
|
|||
|
||||
|
||||
def _build_protocol_kwargs(
|
||||
protocol: Type[HttpProtocol], config: Config
|
||||
protocol: Type[asyncio.Protocol], config: Config
|
||||
) -> Dict[str, Union[int, float]]:
|
||||
if hasattr(protocol, "websocket_handshake"):
|
||||
return {
|
||||
|
@ -660,7 +683,7 @@ def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
|
|||
return sock
|
||||
|
||||
|
||||
def remove_unix_socket(path: str) -> None:
|
||||
def remove_unix_socket(path: Optional[str]) -> None:
|
||||
"""Remove dead unix socket during server exit."""
|
||||
if not path:
|
||||
return
|
||||
|
|
|
@ -92,6 +92,9 @@ class CompositionView:
|
|||
self.handlers = {}
|
||||
self.name = self.__class__.__name__
|
||||
|
||||
def __name__(self):
|
||||
return self.name
|
||||
|
||||
def add(self, methods, handler, stream=False):
|
||||
if stream:
|
||||
handler.is_stream = stream
|
||||
|
|
Loading…
Reference in New Issue
Block a user