Signals Integration (#2160)

* Update some tests

* Resolve #2122 route decorator returning tuple

* Use rc sanic-routing version

* Update unit tests to <:str>

* Minimal working version with some signals implemented

* Add more http signals

* Update ASGI and change listeners to signals

* Allow for dynamic ODE signals

* Allow signals to be stacked

* Begin tests

* Prioritize match_info on keyword argument injection

* WIP on tests

* Compat with signals

* Work through some test coverage

* Passing tests

* Post linting

* Setup proper resets

* coverage reporting

* Fixes from vltr comments

* clear delayed tasks

* Fix bad test

* rm pycache
This commit is contained in:
Adam Hopkins 2021-08-05 22:55:42 +03:00 committed by GitHub
parent 0ba57d4701
commit b1b12e004e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 823 additions and 513 deletions

View File

@ -1,29 +1,44 @@
from sanic import Sanic
from sanic import response
from signal import signal, SIGINT
import asyncio
from signal import SIGINT, signal
import uvloop
from sanic import Sanic, response
from sanic.server import AsyncioServer
app = Sanic(__name__)
@app.listener('after_server_start')
@app.listener("after_server_start")
async def after_start_test(app, loop):
print("Async Server Started!")
@app.route("/")
async def test(request):
return response.json({"answer": "42"})
asyncio.set_event_loop(uvloop.new_event_loop())
serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True)
serv_coro = app.create_server(
host="0.0.0.0", port=8000, return_asyncio_server=True
)
loop = asyncio.get_event_loop()
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
signal(SIGINT, lambda s, f: loop.stop())
server = loop.run_until_complete(serv_task)
server: AsyncioServer = loop.run_until_complete(serv_task) # type: ignore
server.startup()
# When using app.run(), this actually triggers before the serv_coro.
# But, in this example, we are using the convenience method, even if it is
# out of order.
server.before_start()
server.after_start()
try:
loop.run_forever()
except KeyboardInterrupt as e:
except KeyboardInterrupt:
loop.stop()
finally:
server.before_stop()

View File

@ -1,9 +1,12 @@
from __future__ import annotations
import logging
import logging.config
import os
import re
from asyncio import (
AbstractEventLoop,
CancelledError,
Protocol,
ensure_future,
@ -72,19 +75,27 @@ from sanic.server import AsyncioServer, HttpProtocol
from sanic.server import Signal as ServerSignal
from sanic.server import serve, serve_multiple, serve_single
from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.websocket import ConnectionClosed, WebSocketProtocol
class Sanic(BaseSanic):
class Sanic(BaseSanic, metaclass=TouchUpMeta):
"""
The main application instance
"""
__touchup__ = (
"handle_request",
"handle_exception",
"_run_response_middleware",
"_run_request_middleware",
)
__fake_slots__ = (
"_asgi_app",
"_app_registry",
"_asgi_client",
"_blueprint_order",
"_delayed_tasks",
"_future_routes",
"_future_statics",
"_future_middleware",
@ -155,6 +166,7 @@ class Sanic(BaseSanic):
self._asgi_client = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
self._test_client = None
self._test_manager = None
self.asgi = False
@ -192,6 +204,7 @@ class Sanic(BaseSanic):
self.__class__.register_app(self)
self.router.ctx.app = self
self.signal_router.ctx.app = self
if dumps:
BaseHTTPResponse._dumps = dumps # type: ignore
@ -232,9 +245,12 @@ class Sanic(BaseSanic):
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:
self.listener("before_server_start")(
partial(self._loop_add_task, task)
)
task_name = f"sanic.delayed_task.{hash(task)}"
if not self._delayed_tasks:
self.after_server_start(partial(self.dispatch_delayed_tasks))
self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)
def register_listener(self, listener: Callable, event: str) -> Any:
"""
@ -246,12 +262,20 @@ class Sanic(BaseSanic):
"""
try:
_event = ListenerEvent(event)
except ValueError:
valid = ", ".join(ListenerEvent.__members__.values())
_event = ListenerEvent[event.upper()]
except (ValueError, AttributeError):
valid = ", ".join(
map(lambda x: x.lower(), ListenerEvent.__members__.keys())
)
raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}")
self.listeners[_event].append(listener)
if "." in _event:
self.signal(_event.value)(
partial(self._listener, listener=listener)
)
else:
self.listeners[_event.value].append(listener)
return listener
def register_middleware(self, middleware, attach_to: str = "request"):
@ -379,11 +403,17 @@ class Sanic(BaseSanic):
*,
condition: Optional[Dict[str, str]] = None,
context: Optional[Dict[str, Any]] = None,
fail_not_found: bool = True,
inline: bool = False,
reverse: bool = False,
) -> Coroutine[Any, Any, Awaitable[Any]]:
return self.signal_router.dispatch(
event,
context=context,
condition=condition,
inline=inline,
reverse=reverse,
fail_not_found=fail_not_found,
)
async def event(
@ -659,7 +689,7 @@ class Sanic(BaseSanic):
async def handle_exception(
self, request: Request, exception: BaseException
):
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
@ -669,6 +699,12 @@ class Sanic(BaseSanic):
:type exception: BaseException
:raises ServerError: response 500
"""
await self.dispatch(
"http.lifecycle.exception",
inline=True,
context={"request": request, "exception": exception},
)
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
@ -715,7 +751,7 @@ class Sanic(BaseSanic):
f"Invalid response type {response!r} (need HTTPResponse)"
)
async def handle_request(self, request: Request):
async def handle_request(self, request: Request): # no cov
"""Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so
exception handling must be done here
@ -723,10 +759,22 @@ class Sanic(BaseSanic):
:param request: HTTP Request object
:return: Nothing
"""
await self.dispatch(
"http.lifecycle.handle",
inline=True,
context={"request": request},
)
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
try:
await self.dispatch(
"http.routing.before",
inline=True,
context={"request": request},
)
# Fetch handler from router
route, handler, kwargs = self.router.get(
request.path,
@ -734,9 +782,20 @@ class Sanic(BaseSanic):
request.headers.getone("host", None),
)
request._match_info = kwargs
request._match_info = {**kwargs}
request.route = route
await self.dispatch(
"http.routing.after",
inline=True,
context={
"request": request,
"route": route,
"kwargs": kwargs,
"handler": handler,
},
)
if (
request.stream
and request.stream.request_body
@ -772,7 +831,7 @@ class Sanic(BaseSanic):
)
# Run response handler
response = handler(request, **kwargs)
response = handler(request, **request.match_info)
if isawaitable(response):
response = await response
@ -783,6 +842,14 @@ class Sanic(BaseSanic):
# Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
await response.send(end_stream=True)
else:
if not hasattr(handler, "is_websocket"):
@ -1078,11 +1145,6 @@ class Sanic(BaseSanic):
run_async=return_asyncio_server,
)
# Trigger before_start events
await self.trigger_events(
server_settings.get("before_start", []),
server_settings.get("loop"),
)
main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None)
if main_start or main_stop:
@ -1095,17 +1157,9 @@ class Sanic(BaseSanic):
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
)
async def trigger_events(self, events, loop):
"""Trigger events (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
for event in events:
result = event(loop)
if isawaitable(result):
await result
async def _run_request_middleware(self, request, request_name=None):
async def _run_request_middleware(
self, request, request_name=None
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
@ -1118,25 +1172,67 @@ class Sanic(BaseSanic):
request.request_middleware_started = True
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request)
if isawaitable(response):
response = await response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
return None
async def _run_response_middleware(
self, request, response, request_name=None
):
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
@ -1162,10 +1258,6 @@ class Sanic(BaseSanic):
):
"""Helper function used by `run` and `create_server`."""
self.listeners["before_server_start"] = [
self.finalize
] + self.listeners["before_server_start"]
if isinstance(ssl, dict):
# try common aliaseses
cert = ssl.get("cert") or ssl.get("certificate")
@ -1202,10 +1294,6 @@ class Sanic(BaseSanic):
# Register start/stop events
for event_name, settings_name, reverse in (
("before_server_start", "before_start", False),
("after_server_start", "after_start", False),
("before_server_stop", "before_stop", True),
("after_server_stop", "after_stop", True),
("main_process_start", "main_start", False),
("main_process_stop", "main_stop", True),
):
@ -1253,20 +1341,44 @@ class Sanic(BaseSanic):
return ".".join(parts)
@classmethod
def _loop_add_task(cls, task, app, loop):
def _prep_task(cls, task, app, loop):
if callable(task):
try:
loop.create_task(task(app))
task = task(app)
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
task = task()
return task
@classmethod
def _loop_add_task(cls, task, app, loop):
prepped = cls._prep_task(task, app, loop)
loop.create_task(prepped)
@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()
@staticmethod
async def dispatch_delayed_tasks(app, loop):
for name in app._delayed_tasks:
await app.dispatch(name, context={"app": app, "loop": loop})
app._delayed_tasks.clear()
@staticmethod
async def run_delayed_task(app, loop, task):
prepped = app._prep_task(task, app, loop)
await prepped
@staticmethod
async def _listener(
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
):
maybe_coro = listener(app, loop)
if maybe_coro and isawaitable(maybe_coro):
await maybe_coro
# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #
@ -1340,15 +1452,51 @@ class Sanic(BaseSanic):
raise SanicException(f'Sanic app name "{name}" not found.')
# -------------------------------------------------------------------- #
# Static methods
# Lifecycle
# -------------------------------------------------------------------- #
@staticmethod
async def finalize(app, _):
def finalize(self):
try:
app.router.finalize()
if app.signal_router.routes:
app.signal_router.finalize() # noqa
self.router.finalize()
except FinalizationError as e:
if not Sanic.test_mode:
raise e # noqa
raise e
def signalize(self):
try:
self.signal_router.finalize()
except FinalizationError as e:
if not Sanic.test_mode:
raise e
async def _startup(self):
self.signalize()
self.finalize()
TouchUp.run(self)
async def _server_event(
self,
concern: str,
action: str,
loop: Optional[AbstractEventLoop] = None,
) -> None:
event = f"server.{concern}.{action}"
if action not in ("before", "after") or concern not in (
"init",
"shutdown",
):
raise SanicException(f"Invalid server event: {event}")
logger.debug(f"Triggering server events: {event}")
reverse = concern == "shutdown"
if loop is None:
loop = self.loop
await self.dispatch(
event,
fail_not_found=False,
reverse=reverse,
inline=True,
context={
"app": self,
"loop": loop,
},
)

View File

@ -1,6 +1,5 @@
import warnings
from inspect import isawaitable
from typing import Optional
from urllib.parse import quote
@ -18,14 +17,20 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
if "before_server_start" in self.asgi_app.sanic_app.listeners:
if (
"server.init.before"
in self.asgi_app.sanic_app.signal_router.name_index
):
warnings.warn(
'You have set a listener for "before_server_start" '
"in ASGI mode. "
"It will be executed as early as possible, but not before "
"the ASGI server is started."
)
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
if (
"server.shutdown.after"
in self.asgi_app.sanic_app.signal_router.name_index
):
warnings.warn(
'You have set a listener for "after_server_stop" '
"in ASGI mode. "
@ -42,19 +47,9 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single
startup event.
"""
self.asgi_app.sanic_app.router.finalize()
if self.asgi_app.sanic_app.signal_router.routes:
self.asgi_app.sanic_app.signal_router.finalize()
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_start", []
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if response and isawaitable(response):
await response
await self.asgi_app.sanic_app._startup()
await self.asgi_app.sanic_app._server_event("init", "before")
await self.asgi_app.sanic_app._server_event("init", "after")
async def shutdown(self) -> None:
"""
@ -65,16 +60,8 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single
shutdown event.
"""
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_stop", []
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if response and isawaitable(response):
await response
await self.asgi_app.sanic_app._server_event("shutdown", "before")
await self.asgi_app.sanic_app._server_event("shutdown", "after")
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend

View File

@ -21,6 +21,7 @@ from sanic.exceptions import (
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
from sanic.log import access_logger, error_logger, logger
from sanic.touchup import TouchUpMeta
class Stage(Enum):
@ -45,7 +46,7 @@ class Stage(Enum):
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http:
class Http(metaclass=TouchUpMeta):
"""
Internal helper for managing the HTTP request/response cycle
@ -67,9 +68,15 @@ class Http:
HEADER_CEILING = 16_384
HEADER_MAX_SIZE = 0
__touchup__ = (
"http1_request_header",
"http1_response_header",
"read",
)
__slots__ = [
"_send",
"_receive_more",
"dispatch",
"recv_buffer",
"protocol",
"expecting_continue",
@ -97,6 +104,7 @@ class Http:
self.protocol = protocol
self.keep_alive = True
self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch
self.init_for_request()
def init_for_request(self):
@ -183,7 +191,7 @@ class Http:
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self):
async def http1_request_header(self): # no cov
"""
Receive and parse request header into self.request.
"""
@ -212,6 +220,12 @@ class Http:
reqline, *split_headers = raw_headers.split("\r\n")
method, self.url, protocol = reqline.split(" ")
await self.dispatch(
"http.lifecycle.read_head",
inline=True,
context={"head": bytes(head)},
)
if protocol == "HTTP/1.1":
self.keep_alive = True
elif protocol == "HTTP/1.0":
@ -250,6 +264,11 @@ class Http:
transport=self.protocol.transport,
app=self.protocol.app,
)
await self.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": request},
)
# Prepare for request body
self.request_bytes_left = self.request_bytes = 0
@ -280,7 +299,7 @@ class Http:
async def http1_response_header(
self, data: bytes, end_stream: bool
) -> None:
) -> None: # no cov
res = self.response
# Compatibility with simple response body
@ -469,7 +488,7 @@ class Http:
if data:
yield data
async def read(self) -> Optional[bytes]:
async def read(self) -> Optional[bytes]: # no cov
"""
Read some bytes of request body.
"""
@ -543,6 +562,12 @@ class Http:
self.request_bytes_left -= size
await self.dispatch(
"http.lifecycle.read_body",
inline=True,
context={"body": data},
)
return data
# Response methods

View File

@ -9,10 +9,10 @@ class ListenerEvent(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
BEFORE_SERVER_START = auto()
AFTER_SERVER_START = auto()
BEFORE_SERVER_STOP = auto()
AFTER_SERVER_STOP = auto()
BEFORE_SERVER_START = "server.init.before"
AFTER_SERVER_START = "server.init.after"
BEFORE_SERVER_STOP = "server.shutdown.before"
AFTER_SERVER_STOP = "server.shutdown.after"
MAIN_PROCESS_START = auto()
MAIN_PROCESS_STOP = auto()

View File

@ -23,7 +23,7 @@ class SignalMixin:
*,
apply: bool = True,
condition: Dict[str, Any] = None,
) -> Callable[[SignalHandler], FutureSignal]:
) -> Callable[[SignalHandler], SignalHandler]:
"""
For creating a signal handler, used similar to a route handler:
@ -54,7 +54,7 @@ class SignalMixin:
if apply:
self._apply_signal(future_signal)
return future_signal
return handler
return decorator

View File

@ -497,6 +497,10 @@ class Request:
"""
return self._match_info
@match_info.setter
def match_info(self, value):
self._match_info = value
# Transport properties (obtained from local interface only)
@property

View File

@ -13,7 +13,7 @@ from typing import (
Union,
)
from sanic.models.handler_types import ListenerType
from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING:
@ -37,7 +37,7 @@ from time import monotonic as current_time
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.config import Config
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.exceptions import RequestTimeout, SanicException, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol
@ -102,11 +102,15 @@ class ConnInfo:
self.client_port = addr[1]
class HttpProtocol(asyncio.Protocol):
class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
"""
This class provides a basic HTTP implementation of the sanic framework.
"""
__touchup__ = (
"send",
"connection_task",
)
__slots__ = (
# app
"app",
@ -185,7 +189,7 @@ class HttpProtocol(asyncio.Protocol):
self._time = current_time()
self.check_timeouts()
async def connection_task(self):
async def connection_task(self): # no cov
"""
Run a HTTP connection.
@ -194,6 +198,11 @@ class HttpProtocol(asyncio.Protocol):
"""
try:
self._setup_connection()
await self.app.dispatch(
"http.lifecycle.begin",
inline=True,
context={"conn_info": self.conn_info},
)
await self._http.http1()
except CancelledError:
pass
@ -212,6 +221,13 @@ class HttpProtocol(asyncio.Protocol):
self.close()
except BaseException:
error_logger.exception("Closing failed")
finally:
await self.app.dispatch(
"http.lifecycle.complete",
inline=True,
context={"conn_info": self.conn_info},
)
...
async def receive_more(self):
"""
@ -259,13 +275,18 @@ class HttpProtocol(asyncio.Protocol):
except Exception:
error_logger.exception("protocol.check_timeouts")
async def send(self, data):
async def send(self, data): # no cov
"""
Writes data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
await self.app.dispatch(
"http.lifecycle.send",
inline=True,
context={"data": data},
)
self.transport.write(data)
self._time = current_time()
@ -359,52 +380,54 @@ class AsyncioServer:
a user who needs to manage the server lifecycle manually.
"""
__slots__ = (
"loop",
"serve_coro",
"_after_start",
"_before_stop",
"_after_stop",
"server",
"connections",
)
__slots__ = ("app", "connections", "loop", "serve_coro", "server", "init")
def __init__(
self,
app,
loop,
serve_coro,
connections,
after_start: Optional[Iterable[ListenerType]],
before_stop: Optional[Iterable[ListenerType]],
after_stop: Optional[Iterable[ListenerType]],
):
# Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here.
self.app = app
self.connections = connections
self.loop = loop
self.serve_coro = serve_coro
self._after_start = after_start
self._before_stop = before_stop
self._after_stop = after_stop
self.server = None
self.connections = connections
self.init = False
def startup(self):
"""
Trigger "before_server_start" events
"""
self.init = True
return self.app._startup()
def before_start(self):
"""
Trigger "before_server_start" events
"""
return self._server_event("init", "before")
def after_start(self):
"""
Trigger "after_server_start" events
"""
trigger_events(self._after_start, self.loop)
return self._server_event("init", "after")
def before_stop(self):
"""
Trigger "before_server_stop" events
"""
trigger_events(self._before_stop, self.loop)
return self._server_event("shutdown", "before")
def after_stop(self):
"""
Trigger "after_server_stop" events
"""
trigger_events(self._after_stop, self.loop)
return self._server_event("shutdown", "after")
def is_serving(self) -> bool:
if self.server:
@ -442,6 +465,14 @@ class AsyncioServer:
"of asyncio or uvloop."
)
def _server_event(self, concern: str, action: str):
if not self.init:
raise SanicException(
"Cannot dispatch server event without "
"first running server.startup()"
)
return self.app._server_event(concern, action, loop=self.loop)
def __await__(self):
"""
Starts the asyncio server, returns AsyncServerCoro
@ -456,11 +487,7 @@ class AsyncioServer:
def serve(
host,
port,
app,
before_start: Optional[Iterable[ListenerType]] = None,
after_start: Optional[Iterable[ListenerType]] = None,
before_stop: Optional[Iterable[ListenerType]] = None,
after_stop: Optional[Iterable[ListenerType]] = None,
app: Sanic,
ssl: Optional[SSLContext] = None,
sock: Optional[socket.socket] = None,
unix: Optional[str] = None,
@ -542,15 +569,14 @@ def serve(
if run_async:
return AsyncioServer(
app=app,
loop=loop,
serve_coro=server_coroutine,
connections=connections,
after_start=after_start,
before_stop=before_stop,
after_stop=after_stop,
)
trigger_events(before_start, loop)
loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before"))
try:
http_server = loop.run_until_complete(server_coroutine)
@ -558,8 +584,6 @@ def serve(
error_logger.exception("Unable to start server")
return
trigger_events(after_start, loop)
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
@ -571,6 +595,8 @@ def serve(
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
@ -579,7 +605,7 @@ def serve(
logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided
trigger_events(before_stop, loop)
loop.run_until_complete(app._server_event("shutdown", "before"))
# Wait for event loop to finish and all connections to drain
http_server.close()
@ -611,8 +637,7 @@ def serve(
_shutdown = asyncio.gather(*coros)
loop.run_until_complete(_shutdown)
trigger_events(after_stop, loop)
loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix)

View File

@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound # type: ignore
from sanic_routing.utils import path_to_parts # type: ignore
from sanic.exceptions import InvalidSignal
from sanic.log import error_logger, logger
from sanic.models.handler_types import SignalHandler
RESERVED_NAMESPACES = (
"server",
"http",
)
RESERVED_NAMESPACES = {
"server": (
# "server.main.start",
# "server.main.stop",
"server.init.before",
"server.init.after",
"server.shutdown.before",
"server.shutdown.after",
),
"http": (
"http.lifecycle.begin",
"http.lifecycle.complete",
"http.lifecycle.exception",
"http.lifecycle.handle",
"http.lifecycle.read_body",
"http.lifecycle.read_head",
"http.lifecycle.request",
"http.lifecycle.response",
"http.routing.after",
"http.routing.before",
"http.lifecycle.send",
"http.middleware.after",
"http.middleware.before",
),
}
def _blank():
...
class Signal(Route):
@ -59,8 +85,13 @@ class SignalRouter(BaseRouter):
terms.append(extra)
raise NotFound(message % tuple(terms))
# Regex routes evaluate and can extract params directly. They are set
# on param_basket["__params__"]
params = param_basket["__params__"]
if not params:
# If param_basket["__params__"] does not exist, we might have
# param_basket["__matches__"], which are indexed based matches
# on path segments. They should already be cast types.
params = {
param.name: param_basket["__matches__"][idx]
for idx, param in group.params.items()
@ -73,8 +104,18 @@ class SignalRouter(BaseRouter):
event: str,
context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None,
) -> None:
group, handlers, params = self.get(event, condition=condition)
fail_not_found: bool = True,
reverse: bool = False,
) -> Any:
try:
group, handlers, params = self.get(event, condition=condition)
except NotFound as e:
if fail_not_found:
raise e
else:
if self.ctx.app.debug:
error_logger.warning(str(e))
return None
events = [signal.ctx.event for signal in group]
for signal_event in events:
@ -82,12 +123,19 @@ class SignalRouter(BaseRouter):
if context:
params.update(context)
if not reverse:
handlers = handlers[::-1]
try:
for handler in handlers:
if condition is None or condition == handler.__requirements__:
maybe_coroutine = handler(**params)
if isawaitable(maybe_coroutine):
await maybe_coroutine
retval = await maybe_coroutine
if retval:
return retval
elif maybe_coroutine:
return maybe_coroutine
return None
finally:
for signal_event in events:
signal_event.clear()
@ -98,14 +146,23 @@ class SignalRouter(BaseRouter):
*,
context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None,
) -> asyncio.Task:
task = self.ctx.loop.create_task(
self._dispatch(
event,
context=context,
condition=condition,
)
fail_not_found: bool = True,
inline: bool = False,
reverse: bool = False,
) -> Union[asyncio.Task, Any]:
dispatch = self._dispatch(
event,
context=context,
condition=condition,
fail_not_found=fail_not_found and inline,
reverse=reverse,
)
logger.debug(f"Dispatching signal: {event}")
if inline:
return await dispatch
task = asyncio.get_running_loop().create_task(dispatch)
await asyncio.sleep(0)
return task
@ -131,7 +188,9 @@ class SignalRouter(BaseRouter):
append=True,
) # type: ignore
def finalize(self, do_compile: bool = True):
def finalize(self, do_compile: bool = True, do_optimize: bool = False):
self.add(_blank, "sanic.__signal__.__init__")
try:
self.ctx.loop = asyncio.get_running_loop()
except RuntimeError:
@ -140,7 +199,7 @@ class SignalRouter(BaseRouter):
for signal in self.routes:
signal.ctx.event = asyncio.Event()
return super().finalize(do_compile=do_compile)
return super().finalize(do_compile=do_compile, do_optimize=do_optimize)
def _build_event_parts(self, event: str) -> Tuple[str, str, str]:
parts = path_to_parts(event, self.delimiter)
@ -151,7 +210,11 @@ class SignalRouter(BaseRouter):
):
raise InvalidSignal("Invalid signal event: %s" % event)
if parts[0] in RESERVED_NAMESPACES:
if (
parts[0] in RESERVED_NAMESPACES
and event not in RESERVED_NAMESPACES[parts[0]]
and not (parts[2].startswith("<") and parts[2].endswith(">"))
):
raise InvalidSignal(
"Cannot declare reserved signal event: %s" % event
)

View File

@ -0,0 +1,8 @@
from .meta import TouchUpMeta
from .service import TouchUp
__all__ = (
"TouchUp",
"TouchUpMeta",
)

22
sanic/touchup/meta.py Normal file
View File

@ -0,0 +1,22 @@
from sanic.exceptions import SanicException
from .service import TouchUp
class TouchUpMeta(type):
def __new__(cls, name, bases, attrs, **kwargs):
gen_class = super().__new__(cls, name, bases, attrs, **kwargs)
methods = attrs.get("__touchup__")
attrs["__touched__"] = False
if methods:
for method in methods:
if method not in attrs:
raise SanicException(
"Cannot perform touchup on non-existent method: "
f"{name}.{method}"
)
TouchUp.register(gen_class, method)
return gen_class

View File

@ -0,0 +1,5 @@
from .base import BaseScheme
from .ode import OptionalDispatchEvent # noqa
__all__ = ("BaseScheme",)

View File

@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from typing import Set, Type
class BaseScheme(ABC):
ident: str
_registry: Set[Type] = set()
def __init__(self, app) -> None:
self.app = app
@abstractmethod
def run(self, method, module_globals) -> None:
...
def __init_subclass__(cls):
BaseScheme._registry.add(cls)
def __call__(self, method, module_globals):
return self.run(method, module_globals)

View File

@ -0,0 +1,67 @@
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any
from sanic.log import logger
from .base import BaseScheme
class OptionalDispatchEvent(BaseScheme):
ident = "ODE"
def __init__(self, app) -> None:
super().__init__(app)
self._registered_events = [
signal.path for signal in app.signal_router.routes
]
def run(self, method, module_globals):
raw_source = getsource(method)
src = dedent(raw_source)
tree = parse(src)
node = RemoveDispatch(self._registered_events).visit(tree)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]
class RemoveDispatch(NodeTransformer):
def __init__(self, registered_events) -> None:
self._registered_events = registered_events
def visit_Expr(self, node: Expr) -> Any:
call = node.value
if isinstance(call, Await):
call = call.value
func = getattr(call, "func", None)
args = getattr(call, "args", None)
if not func or not args:
return node
if isinstance(func, Attribute) and func.attr == "dispatch":
event = args[0]
if hasattr(event, "s"):
event_name = getattr(event, "value", event.s)
if self._not_registered(event_name):
logger.debug(f"Disabling event: {event_name}")
return None
return node
def _not_registered(self, event_name):
dynamic = []
for event in self._registered_events:
if event.endswith(">"):
namespace_concern, _ = event.rsplit(".", 1)
dynamic.append(namespace_concern)
namespace_concern, _ = event_name.rsplit(".", 1)
return (
event_name not in self._registered_events
and namespace_concern not in dynamic
)

33
sanic/touchup/service.py Normal file
View File

@ -0,0 +1,33 @@
from inspect import getmembers, getmodule
from typing import Set, Tuple, Type
from .schemes import BaseScheme
class TouchUp:
_registry: Set[Tuple[Type, str]] = set()
@classmethod
def run(cls, app):
for target, method_name in cls._registry:
method = getattr(target, method_name)
if app.test_mode:
placeholder = f"_{method_name}"
if hasattr(target, placeholder):
method = getattr(target, placeholder)
else:
setattr(target, placeholder, method)
module = getmodule(target)
module_globals = dict(getmembers(module))
for scheme in BaseScheme._registry:
modified = scheme(app)(method, module_globals)
setattr(target, method_name, modified)
target.__touched__ = True
@classmethod
def register(cls, target, method_name):
cls._registry.add((target, method_name))

View File

@ -8,7 +8,7 @@ import traceback
from gunicorn.workers import base # type: ignore
from sanic.log import logger
from sanic.server import HttpProtocol, Signal, serve, trigger_events
from sanic.server import HttpProtocol, Signal, serve
from sanic.websocket import WebSocketProtocol
@ -68,10 +68,10 @@ class GunicornWorker(base.Worker):
)
self._server_settings["signal"] = self.signal
self._server_settings.pop("sock")
trigger_events(
self._server_settings.get("before_start", []), self.loop
self._await(self.app.callable._startup())
self._await(
self.app.callable._server_event("init", "before", loop=self.loop)
)
self._server_settings["before_start"] = ()
main_start = self._server_settings.pop("main_start", None)
main_stop = self._server_settings.pop("main_stop", None)
@ -82,24 +82,29 @@ class GunicornWorker(base.Worker):
"with GunicornWorker"
)
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
try:
self.loop.run_until_complete(self._runner)
self._await(self._run())
self.app.callable.is_running = True
trigger_events(
self._server_settings.get("after_start", []), self.loop
self._await(
self.app.callable._server_event(
"init", "after", loop=self.loop
)
)
self.loop.run_until_complete(self._check_alive())
trigger_events(
self._server_settings.get("before_stop", []), self.loop
self._await(
self.app.callable._server_event(
"shutdown", "before", loop=self.loop
)
)
self.loop.run_until_complete(self.close())
except BaseException:
traceback.print_exc()
finally:
try:
trigger_events(
self._server_settings.get("after_stop", []), self.loop
self._await(
self.app.callable._server_event(
"shutdown", "after", loop=self.loop
)
)
except BaseException:
traceback.print_exc()
@ -238,3 +243,7 @@ class GunicornWorker(base.Worker):
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)
def _await(self, coro):
fut = asyncio.ensure_future(coro, loop=self.loop)
self.loop.run_until_complete(fut)

View File

@ -93,7 +93,7 @@ requirements = [
]
tests_require = [
"sanic-testing==0.7.0b1",
"sanic-testing>=0.7.0b2",
"pytest==5.2.1",
"coverage==5.3",
"gunicorn==20.0.4",

View File

@ -1,3 +1,5 @@
import asyncio
import logging
import random
import re
import string
@ -9,10 +11,12 @@ from typing import Tuple
import pytest
from sanic_routing.exceptions import RouteExists
from sanic_testing.testing import PORT
from sanic import Sanic
from sanic.constants import HTTP_METHODS
from sanic.router import Router
from sanic.touchup.service import TouchUp
slugify = re.compile(r"[^a-zA-Z0-9_\-]")
@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]:
collect_ignore = ["test_worker.py"]
@pytest.fixture
def caplog(caplog):
yield caplog
async def _handler(request):
"""
Dummy placeholder method used for route resolver when creating a new
@ -41,33 +40,32 @@ async def _handler(request):
TYPE_TO_GENERATOR_MAP = {
"string": lambda: "".join(
"str": lambda: "".join(
[random.choice(string.ascii_lowercase) for _ in range(4)]
),
"int": lambda: random.choice(range(1000000)),
"number": lambda: random.random(),
"float": lambda: random.random(),
"alpha": lambda: "".join(
[random.choice(string.ascii_lowercase) for _ in range(4)]
),
"uuid": lambda: str(uuid.uuid1()),
}
CACHE = {}
class RouteStringGenerator:
ROUTE_COUNT_PER_DEPTH = 100
HTTP_METHODS = HTTP_METHODS
ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"]
ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"]
def generate_random_direct_route(self, max_route_depth=4):
routes = []
for depth in range(1, max_route_depth + 1):
for _ in range(self.ROUTE_COUNT_PER_DEPTH):
route = "/".join(
[
TYPE_TO_GENERATOR_MAP.get("string")()
for _ in range(depth)
]
[TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)]
)
route = route.replace(".", "", -1)
route_detail = (random.choice(self.HTTP_METHODS), route)
@ -83,7 +81,7 @@ class RouteStringGenerator:
new_route_part = "/".join(
[
"<{}:{}>".format(
TYPE_TO_GENERATOR_MAP.get("string")(),
TYPE_TO_GENERATOR_MAP.get("str")(),
random.choice(self.ROUTE_PARAM_TYPES),
)
for _ in range(max_route_depth - current_length)
@ -98,7 +96,7 @@ class RouteStringGenerator:
def generate_url_for_template(template):
url = template
for pattern, param_type in re.findall(
re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"),
re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"),
template,
):
value = TYPE_TO_GENERATOR_MAP.get(param_type)()
@ -141,5 +139,33 @@ def url_param_generator():
@pytest.fixture(scope="function")
def app(request):
if not CACHE:
for target, method_name in TouchUp._registry:
CACHE[method_name] = getattr(target, method_name)
app = Sanic(slugify.sub("-", request.node.name))
return app
yield app
for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name])
@pytest.fixture(scope="function")
def run_startup(caplog):
def run(app):
nonlocal caplog
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
with caplog.at_level(logging.DEBUG):
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop._stopping = False
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
return caplog.record_tuples
return run

View File

@ -89,7 +89,7 @@ def test_debug(cmd):
out, err, exitcode = capture(command)
lines = out.split(b"\n")
app_info = lines[9]
app_info = lines[26]
info = json.loads(app_info)
assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO
@ -103,7 +103,7 @@ def test_auto_reload(cmd):
out, err, exitcode = capture(command)
lines = out.split(b"\n")
app_info = lines[9]
app_info = lines[26]
info = json.loads(app_info)
assert info["debug"] is False
@ -118,7 +118,7 @@ def test_access_logs(cmd, expected):
out, err, exitcode = capture(command)
lines = out.split(b"\n")
app_info = lines[9]
app_info = lines[26]
info = json.loads(app_info)
assert info["access_log"] is expected

View File

@ -1,6 +1,5 @@
import asyncio
from queue import Queue
from threading import Event
from sanic.response import text
@ -13,8 +12,6 @@ def test_create_task(app):
await asyncio.sleep(0.05)
e.set()
app.add_task(coro)
@app.route("/early")
def not_set(request):
return text(str(e.is_set()))
@ -24,24 +21,30 @@ def test_create_task(app):
await asyncio.sleep(0.1)
return text(str(e.is_set()))
app.add_task(coro)
request, response = app.test_client.get("/early")
assert response.body == b"False"
app.signal_router.reset()
app.add_task(coro)
request, response = app.test_client.get("/late")
assert response.body == b"True"
def test_create_task_with_app_arg(app):
q = Queue()
@app.after_server_start
async def setup_q(app, _):
app.ctx.q = asyncio.Queue()
@app.route("/")
def not_set(request):
return "hello"
async def not_set(request):
return text(await request.app.ctx.q.get())
async def coro(app):
q.put(app.name)
await app.ctx.q.put(app.name)
app.add_task(coro)
request, response = app.test_client.get("/")
assert q.get() == "test_create_task_with_app_arg"
_, response = app.test_client.get("/")
assert response.text == "test_create_task_with_app_arg"

View File

@ -127,7 +127,6 @@ def test_html_traceback_output_in_debug_mode():
soup = BeautifulSoup(response.body, "html.parser")
html = str(soup)
assert "response = handler(request, **kwargs)" in html
assert "handler_4" in html
assert "foo = bar" in html
@ -151,7 +150,6 @@ def test_chained_exception_handler():
soup = BeautifulSoup(response.body, "html.parser")
html = str(soup)
assert "response = handler(request, **kwargs)" in html
assert "handler_6" in html
assert "foo = 1 / arg" in html
assert "ValueError" in html

View File

@ -2,16 +2,13 @@ import asyncio
import platform
from asyncio import sleep as aio_sleep
from json import JSONDecodeError
from os import environ
import httpcore
import httpx
import pytest
from sanic_testing.testing import HOST, SanicTestClient
from sanic_testing.reusable import ReusableClient
from sanic import Sanic, server
from sanic import Sanic
from sanic.compat import OS_IS_WINDOWS
from sanic.response import text
@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool):
last_reused_connection = None
async def _get_connection_from_pool(self, *args, **kwargs):
conn = await super()._get_connection_from_pool(*args, **kwargs)
self.__class__.last_reused_connection = conn
return conn
class ResusableSanicSession(httpx.AsyncClient):
def __init__(self, *args, **kwargs) -> None:
transport = ReusableSanicConnectionPool()
super().__init__(transport=transport, *args, **kwargs)
class ReuseableSanicTestClient(SanicTestClient):
def __init__(self, app, loop=None):
super().__init__(app)
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._server = None
self._tcp_connector = None
self._session = None
def get_new_session(self):
return ResusableSanicSession()
# Copied from SanicTestClient, but with some changes to reuse the
# same loop for the same app.
def _sanic_endpoint_test(
self,
method="get",
uri="/",
gather_request=True,
debug=False,
server_kwargs=None,
*request_args,
**request_kwargs,
):
loop = self._loop
results = [None, None]
exceptions = []
server_kwargs = server_kwargs or {"return_asyncio_server": True}
if gather_request:
def _collect_request(request):
if results[0] is None:
results[0] = request
self.app.request_middleware.appendleft(_collect_request)
if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
url = uri
else:
uri = uri if uri.startswith("/") else f"/{uri}"
scheme = "http"
url = f"{scheme}://{HOST}:{PORT}{uri}"
@self.app.listener("after_server_start")
async def _collect_response(loop):
try:
response = await self._local_request(
method, url, *request_args, **request_kwargs
)
results[-1] = response
except Exception as e2:
exceptions.append(e2)
if self._server is not None:
_server = self._server
else:
_server_co = self.app.create_server(
host=HOST, debug=debug, port=PORT, **server_kwargs
)
server.trigger_events(
self.app.listeners["before_server_start"], loop
)
try:
loop._stopping = False
_server = loop.run_until_complete(_server_co)
except Exception as e1:
raise e1
self._server = _server
server.trigger_events(self.app.listeners["after_server_start"], loop)
self.app.listeners["after_server_start"].pop()
if exceptions:
raise ValueError(f"Exception during request: {exceptions}")
if gather_request:
self.app.request_middleware.pop()
try:
request, response = results
return request, response
except Exception:
raise ValueError(
f"Request and response object expected, got ({results})"
)
else:
try:
return results[-1]
except Exception:
raise ValueError(f"Request object expected, got ({results})")
def kill_server(self):
try:
if self._server:
self._server.close()
self._loop.run_until_complete(self._server.wait_closed())
self._server = None
if self._session:
self._loop.run_until_complete(self._session.aclose())
self._session = None
except Exception as e3:
raise e3
# Copied from SanicTestClient, but with some changes to reuse the
# same TCPConnection and the sane ClientSession more than once.
# Note, you cannot use the same session if you are in a _different_
# loop, so the changes above are required too.
async def _local_request(self, method, url, *args, **kwargs):
raw_cookies = kwargs.pop("raw_cookies", None)
request_keepalive = kwargs.pop(
"request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"]
)
if not self._session:
self._session = self.get_new_session()
try:
response = await getattr(self._session, method.lower())(
url, timeout=request_keepalive, *args, **kwargs
)
except NameError:
raise Exception(response.status_code)
try:
response.json = response.json()
except (JSONDecodeError, UnicodeDecodeError):
response.json = None
response.body = await response.aread()
response.status = response.status_code
response.content_type = response.headers.get("content-type")
if raw_cookies:
response.raw_cookies = {}
for cookie in response.cookies:
response.raw_cookies[cookie.name] = cookie
return response
keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse")
keep_alive_app_client_timeout = Sanic("test_ka_client_timeout")
keep_alive_app_server_timeout = Sanic("test_ka_server_timeout")
@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully
reuse the existing connection."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT)
with client:
headers = {"Connection": "keep-alive"}
request, response = client.get("/1", headers=headers)
assert response.status == 200
assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(1))
request, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
assert ReusableSanicConnectionPool.last_reused_connection
finally:
client.kill_server()
assert request.protocol.state["requests_count"] == 2
@pytest.mark.skipif(
@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse():
def test_keep_alive_client_timeout():
"""If the server keep-alive timeout is longer than the client
keep-alive timeout, client will try to create a new connection here."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(
keep_alive_app_client_timeout, loop=loop, port=PORT
)
with client:
headers = {"Connection": "keep-alive"}
_, response = client.get("/1", headers=headers, request_keepalive=1)
request, response = client.get("/1", headers=headers, timeout=1)
assert response.status == 200
assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(2))
_, response = client.get("/1", request_keepalive=1)
assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
request, response = client.get("/1", timeout=1)
assert request.protocol.state["requests_count"] == 1
@pytest.mark.skipif(
@ -277,22 +117,23 @@ def test_keep_alive_server_timeout():
keep-alive timeout, the client will either a 'Connection reset' error
_or_ a new connection. Depending on how the event-loop handles the
broken server connection."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(
keep_alive_app_server_timeout, loop=loop, port=PORT
)
with client:
headers = {"Connection": "keep-alive"}
_, response = client.get("/1", headers=headers, request_keepalive=60)
request, response = client.get("/1", headers=headers, timeout=60)
assert response.status == 200
assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(3))
_, response = client.get("/1", request_keepalive=60)
request, response = client.get("/1", timeout=60)
assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
assert request.protocol.state["requests_count"] == 1
@pytest.mark.skipif(
@ -300,10 +141,10 @@ def test_keep_alive_server_timeout():
reason="Not testable with current client",
)
def test_keep_alive_connection_context():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_context, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT)
with client:
headers = {"Connection": "keep-alive"}
request1, _ = client.post("/ctx", headers=headers)
@ -315,5 +156,4 @@ def test_keep_alive_connection_context():
assert (
request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello"
)
finally:
client.kill_server()
assert request2.protocol.state["requests_count"] == 2

View File

@ -6,85 +6,37 @@ from sanic_testing.testing import PORT
from sanic.config import BASE_LOGO
def test_logo_base(app, caplog):
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
def test_logo_base(app, run_startup):
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == BASE_LOGO
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == BASE_LOGO
def test_logo_false(app, caplog):
def test_logo_false(app, caplog, run_startup):
app.config.LOGO = False
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
banner, port = caplog.record_tuples[0][2].rsplit(":", 1)
assert caplog.record_tuples[0][1] == logging.INFO
banner, port = logs[0][2].rsplit(":", 1)
assert logs[0][1] == logging.INFO
assert banner == "Goin' Fast @ http://127.0.0.1"
assert int(port) > 0
def test_logo_true(app, caplog):
def test_logo_true(app, run_startup):
app.config.LOGO = True
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == BASE_LOGO
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == BASE_LOGO
def test_logo_custom(app, caplog):
def test_logo_custom(app, run_startup):
app.config.LOGO = "My Custom Logo"
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == "My Custom Logo"
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == "My Custom Logo"

View File

@ -2,6 +2,7 @@ import asyncio
import httpcore
import httpx
import pytest
from sanic_testing.testing import SanicTestClient
@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient):
return DelayableSanicSession(request_delay=self._request_delay)
request_timeout_default_app = Sanic("test_request_timeout_default")
request_no_timeout_app = Sanic("test_request_no_timeout")
request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6
request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6
@pytest.fixture
def request_no_timeout_app():
app = Sanic("test_request_no_timeout")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@request_timeout_default_app.route("/1")
async def handler1(request):
return text("OK")
@pytest.fixture
def request_timeout_default_app():
app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
@request_no_timeout_app.route("/1")
async def handler2(request):
return text("OK")
@request_timeout_default_app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
def test_default_server_error_request_timeout():
def test_default_server_error_request_timeout(request_timeout_default_app):
client = DelayableSanicTestClient(request_timeout_default_app, 2)
request, response = client.get("/1")
_, response = client.get("/1")
assert response.status == 408
assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout():
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
request, response = client.get("/1")
_, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
def test_default_server_error_websocket_request_timeout():
def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = {
"Upgrade": "websocket",
@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout():
}
client = DelayableSanicTestClient(request_timeout_default_app, 2)
request, response = client.get("/ws1", headers=headers)
_, response = client.get("/ws1", headers=headers)
assert response.status == 408
assert "Request Timeout" in response.text

View File

@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app):
@pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url):
ev = asyncio.Event()
@app.after_server_start
async def setup_ev(app, _):
app.ctx.ev = asyncio.Event()
@app.websocket(url)
async def handler(request, ws):
ev.set()
request.app.ctx.ev.set()
request, response = await app.asgi_client.websocket(url)
assert ev.is_set()
@app.get("/ev")
async def check(request):
return json({"set": request.app.ctx.ev.is_set()})
_, response = await app.asgi_client.websocket(url)
_, response = await app.asgi_client.get("/")
assert response.json["set"]
def test_websocket_route_with_subprotocols(app):
results = []
@pytest.mark.parametrize(
"subprotocols,expected",
(
(["bar"], "bar"),
(["bar", "foo"], "bar"),
(["baz"], None),
(None, None),
),
)
def test_websocket_route_with_subprotocols(app, subprotocols, expected):
results = "unset"
@app.websocket("/ws", subprotocols=["foo", "bar"])
async def handler(request, ws):
results.append(ws.subprotocol)
nonlocal results
results = ws.subprotocol
assert ws.subprotocol is not None
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"])
assert response.opened is True
assert results == ["bar"]
_, response = SanicTestClient(app).websocket(
"/ws", subprotocols=["bar", "foo"]
"/ws", subprotocols=subprotocols
)
assert response.opened is True
assert results == ["bar", "bar"]
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"])
assert response.opened is True
assert results == ["bar", "bar", None]
_, response = SanicTestClient(app).websocket("/ws")
assert response.opened is True
assert results == ["bar", "bar", None, None]
assert results == expected
@pytest.mark.parametrize("strict_slashes", [True, False, None])

View File

@ -8,7 +8,7 @@ import pytest
from sanic_testing.testing import HOST, PORT
from sanic.exceptions import InvalidUsage
from sanic.exceptions import InvalidUsage, SanicException
AVAILABLE_LISTENERS = [
@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app):
async def init_db(app, loop):
app.db = MySanicDb()
await app.create_server(debug=True, return_asyncio_server=True, port=PORT)
srv = await app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
await srv.startup()
await srv.before_start()
assert hasattr(app, "db")
assert isinstance(app.db, MySanicDb)
@ -157,14 +161,15 @@ def test_create_server_trigger_events(app):
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task)
server.after_start()
loop.run_until_complete(server.startup())
loop.run_until_complete(server.after_start())
try:
loop.run_forever()
except KeyboardInterrupt as e:
except KeyboardInterrupt:
loop.stop()
finally:
# Run the on_stop function if provided
server.before_stop()
loop.run_until_complete(server.before_stop())
# Wait for server to close
close_task = server.close()
@ -174,5 +179,19 @@ def test_create_server_trigger_events(app):
signal.stopped = True
for connection in server.connections:
connection.close_if_idle()
server.after_stop()
loop.run_until_complete(server.after_stop())
assert flag1 and flag2 and flag3
@pytest.mark.asyncio
async def test_missing_startup_raises_exception(app):
@app.listener("before_server_start")
async def init_db(app, loop):
...
srv = await app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
with pytest.raises(SanicException):
await srv.before_start()

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock
import pytest
from sanic_testing.reusable import ReusableClient
from sanic_testing.testing import HOST, PORT
from sanic.compat import ctrlc_workaround_for_windows
@ -28,9 +29,13 @@ def set_loop(app, loop):
signal.signal = mock
else:
loop.add_signal_handler = mock
print(">>>>>>>>>>>>>>>1", id(loop))
print(">>>>>>>>>>>>>>>1", loop.add_signal_handler)
def after(app, loop):
print(">>>>>>>>>>>>>>>2", id(loop))
print(">>>>>>>>>>>>>>>2", loop.add_signal_handler)
calledq.put(mock.called)

View File

@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app):
app.signal_router.finalize()
assert len(app.signal_router.routes) == 3
await app.dispatch("foo.bar.baz")
assert counter == 2
@ -331,7 +332,8 @@ def test_event_on_bp_not_registered():
"event,expected",
(
("foo.bar.baz", True),
("server.init.before", False),
("server.init.before", True),
("server.init.somethingelse", False),
("http.request.start", False),
("sanic.notice.anything", True),
),

21
tests/test_touchup.py Normal file
View File

@ -0,0 +1,21 @@
import logging
from sanic.signals import RESERVED_NAMESPACES
from sanic.touchup import TouchUp
def test_touchup_methods(app):
assert len(TouchUp._registry) == 9
async def test_ode_removes_dispatch_events(app, caplog):
with caplog.at_level(logging.DEBUG, logger="sanic.root"):
await app._startup()
logs = caplog.record_tuples
for signal in RESERVED_NAMESPACES["http"]:
assert (
"sanic.root",
logging.DEBUG,
f"Disabling event: {signal}",
) in logs

View File

@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app):
)
def test_websocket_bp_route_name(app):
@pytest.mark.parametrize(
"name,expected",
(
("test_route", "/bp/route"),
("test_route2", "/bp/route2"),
("foobar_3", "/bp/route3"),
),
)
def test_websocket_bp_route_name(app, name, expected):
"""Tests that blueprint websocket route is named."""
event = asyncio.Event()
bp = Blueprint("test_bp", url_prefix="/bp")
@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app):
uri = app.url_for("test_bp.main")
assert uri == "/bp/main"
uri = app.url_for("test_bp.test_route")
assert uri == "/bp/route"
uri = app.url_for(f"test_bp.{name}")
assert uri == expected
request, response = SanicTestClient(app).websocket(uri)
assert response.opened is True
assert event.is_set()
event.clear()
uri = app.url_for("test_bp.test_route2")
assert uri == "/bp/route2"
request, response = SanicTestClient(app).websocket(uri)
assert response.opened is True
assert event.is_set()
uri = app.url_for("test_bp.foobar_3")
assert uri == "/bp/route3"
# TODO: add test with a route with multiple hosts
# TODO: add test with a route with _host in url_for

View File

@ -10,7 +10,7 @@ extras = test
commands =
pytest {posargs:tests --cov sanic}
- coverage combine --append
coverage report -m
coverage report -m -i
coverage html -i
[testenv:lint]