Signals Integration (#2160)

* Update some tests

* Resolve #2122 route decorator returning tuple

* Use rc sanic-routing version

* Update unit tests to <:str>

* Minimal working version with some signals implemented

* Add more http signals

* Update ASGI and change listeners to signals

* Allow for dynamic ODE signals

* Allow signals to be stacked

* Begin tests

* Prioritize match_info on keyword argument injection

* WIP on tests

* Compat with signals

* Work through some test coverage

* Passing tests

* Post linting

* Setup proper resets

* coverage reporting

* Fixes from vltr comments

* clear delayed tasks

* Fix bad test

* rm pycache
This commit is contained in:
Adam Hopkins 2021-08-05 22:55:42 +03:00 committed by GitHub
parent 0ba57d4701
commit b1b12e004e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 823 additions and 513 deletions

View File

@ -1,29 +1,44 @@
from sanic import Sanic
from sanic import response
from signal import signal, SIGINT
import asyncio import asyncio
from signal import SIGINT, signal
import uvloop import uvloop
from sanic import Sanic, response
from sanic.server import AsyncioServer
app = Sanic(__name__) app = Sanic(__name__)
@app.listener('after_server_start')
@app.listener("after_server_start")
async def after_start_test(app, loop): async def after_start_test(app, loop):
print("Async Server Started!") print("Async Server Started!")
@app.route("/") @app.route("/")
async def test(request): async def test(request):
return response.json({"answer": "42"}) return response.json({"answer": "42"})
asyncio.set_event_loop(uvloop.new_event_loop()) asyncio.set_event_loop(uvloop.new_event_loop())
serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True) serv_coro = app.create_server(
host="0.0.0.0", port=8000, return_asyncio_server=True
)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
serv_task = asyncio.ensure_future(serv_coro, loop=loop) serv_task = asyncio.ensure_future(serv_coro, loop=loop)
signal(SIGINT, lambda s, f: loop.stop()) signal(SIGINT, lambda s, f: loop.stop())
server = loop.run_until_complete(serv_task) server: AsyncioServer = loop.run_until_complete(serv_task) # type: ignore
server.startup()
# When using app.run(), this actually triggers before the serv_coro.
# But, in this example, we are using the convenience method, even if it is
# out of order.
server.before_start()
server.after_start() server.after_start()
try: try:
loop.run_forever() loop.run_forever()
except KeyboardInterrupt as e: except KeyboardInterrupt:
loop.stop() loop.stop()
finally: finally:
server.before_stop() server.before_stop()

View File

@ -1,9 +1,12 @@
from __future__ import annotations
import logging import logging
import logging.config import logging.config
import os import os
import re import re
from asyncio import ( from asyncio import (
AbstractEventLoop,
CancelledError, CancelledError,
Protocol, Protocol,
ensure_future, ensure_future,
@ -72,19 +75,27 @@ from sanic.server import AsyncioServer, HttpProtocol
from sanic.server import Signal as ServerSignal from sanic.server import Signal as ServerSignal
from sanic.server import serve, serve_multiple, serve_single from sanic.server import serve, serve_multiple, serve_single
from sanic.signals import Signal, SignalRouter from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.websocket import ConnectionClosed, WebSocketProtocol from sanic.websocket import ConnectionClosed, WebSocketProtocol
class Sanic(BaseSanic): class Sanic(BaseSanic, metaclass=TouchUpMeta):
""" """
The main application instance The main application instance
""" """
__touchup__ = (
"handle_request",
"handle_exception",
"_run_response_middleware",
"_run_request_middleware",
)
__fake_slots__ = ( __fake_slots__ = (
"_asgi_app", "_asgi_app",
"_app_registry", "_app_registry",
"_asgi_client", "_asgi_client",
"_blueprint_order", "_blueprint_order",
"_delayed_tasks",
"_future_routes", "_future_routes",
"_future_statics", "_future_statics",
"_future_middleware", "_future_middleware",
@ -155,6 +166,7 @@ class Sanic(BaseSanic):
self._asgi_client = None self._asgi_client = None
self._blueprint_order: List[Blueprint] = [] self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
self._test_client = None self._test_client = None
self._test_manager = None self._test_manager = None
self.asgi = False self.asgi = False
@ -192,6 +204,7 @@ class Sanic(BaseSanic):
self.__class__.register_app(self) self.__class__.register_app(self)
self.router.ctx.app = self self.router.ctx.app = self
self.signal_router.ctx.app = self
if dumps: if dumps:
BaseHTTPResponse._dumps = dumps # type: ignore BaseHTTPResponse._dumps = dumps # type: ignore
@ -232,9 +245,12 @@ class Sanic(BaseSanic):
loop = self.loop # Will raise SanicError if loop is not started loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop) self._loop_add_task(task, self, loop)
except SanicException: except SanicException:
self.listener("before_server_start")( task_name = f"sanic.delayed_task.{hash(task)}"
partial(self._loop_add_task, task) if not self._delayed_tasks:
) self.after_server_start(partial(self.dispatch_delayed_tasks))
self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)
def register_listener(self, listener: Callable, event: str) -> Any: def register_listener(self, listener: Callable, event: str) -> Any:
""" """
@ -246,12 +262,20 @@ class Sanic(BaseSanic):
""" """
try: try:
_event = ListenerEvent(event) _event = ListenerEvent[event.upper()]
except ValueError: except (ValueError, AttributeError):
valid = ", ".join(ListenerEvent.__members__.values()) valid = ", ".join(
map(lambda x: x.lower(), ListenerEvent.__members__.keys())
)
raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}")
self.listeners[_event].append(listener) if "." in _event:
self.signal(_event.value)(
partial(self._listener, listener=listener)
)
else:
self.listeners[_event.value].append(listener)
return listener return listener
def register_middleware(self, middleware, attach_to: str = "request"): def register_middleware(self, middleware, attach_to: str = "request"):
@ -379,11 +403,17 @@ class Sanic(BaseSanic):
*, *,
condition: Optional[Dict[str, str]] = None, condition: Optional[Dict[str, str]] = None,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
fail_not_found: bool = True,
inline: bool = False,
reverse: bool = False,
) -> Coroutine[Any, Any, Awaitable[Any]]: ) -> Coroutine[Any, Any, Awaitable[Any]]:
return self.signal_router.dispatch( return self.signal_router.dispatch(
event, event,
context=context, context=context,
condition=condition, condition=condition,
inline=inline,
reverse=reverse,
fail_not_found=fail_not_found,
) )
async def event( async def event(
@ -659,7 +689,7 @@ class Sanic(BaseSanic):
async def handle_exception( async def handle_exception(
self, request: Request, exception: BaseException self, request: Request, exception: BaseException
): ): # no cov
""" """
A handler that catches specific exceptions and outputs a response. A handler that catches specific exceptions and outputs a response.
@ -669,6 +699,12 @@ class Sanic(BaseSanic):
:type exception: BaseException :type exception: BaseException
:raises ServerError: response 500 :raises ServerError: response 500
""" """
await self.dispatch(
"http.lifecycle.exception",
inline=True,
context={"request": request, "exception": exception},
)
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
@ -715,7 +751,7 @@ class Sanic(BaseSanic):
f"Invalid response type {response!r} (need HTTPResponse)" f"Invalid response type {response!r} (need HTTPResponse)"
) )
async def handle_request(self, request: Request): async def handle_request(self, request: Request): # no cov
"""Take a request from the HTTP Server and return a response object """Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so to be sent back The HTTP Server only expects a response object, so
exception handling must be done here exception handling must be done here
@ -723,10 +759,22 @@ class Sanic(BaseSanic):
:param request: HTTP Request object :param request: HTTP Request object
:return: Nothing :return: Nothing
""" """
await self.dispatch(
"http.lifecycle.handle",
inline=True,
context={"request": request},
)
# Define `response` var here to remove warnings about # Define `response` var here to remove warnings about
# allocation before assignment below. # allocation before assignment below.
response = None response = None
try: try:
await self.dispatch(
"http.routing.before",
inline=True,
context={"request": request},
)
# Fetch handler from router # Fetch handler from router
route, handler, kwargs = self.router.get( route, handler, kwargs = self.router.get(
request.path, request.path,
@ -734,9 +782,20 @@ class Sanic(BaseSanic):
request.headers.getone("host", None), request.headers.getone("host", None),
) )
request._match_info = kwargs request._match_info = {**kwargs}
request.route = route request.route = route
await self.dispatch(
"http.routing.after",
inline=True,
context={
"request": request,
"route": route,
"kwargs": kwargs,
"handler": handler,
},
)
if ( if (
request.stream request.stream
and request.stream.request_body and request.stream.request_body
@ -772,7 +831,7 @@ class Sanic(BaseSanic):
) )
# Run response handler # Run response handler
response = handler(request, **kwargs) response = handler(request, **request.match_info)
if isawaitable(response): if isawaitable(response):
response = await response response = await response
@ -783,6 +842,14 @@ class Sanic(BaseSanic):
# Make sure that response is finished / run StreamingHTTP callback # Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
await response.send(end_stream=True) await response.send(end_stream=True)
else: else:
if not hasattr(handler, "is_websocket"): if not hasattr(handler, "is_websocket"):
@ -1078,11 +1145,6 @@ class Sanic(BaseSanic):
run_async=return_asyncio_server, run_async=return_asyncio_server,
) )
# Trigger before_start events
await self.trigger_events(
server_settings.get("before_start", []),
server_settings.get("loop"),
)
main_start = server_settings.pop("main_start", None) main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None) main_stop = server_settings.pop("main_stop", None)
if main_start or main_stop: if main_start or main_stop:
@ -1095,17 +1157,9 @@ class Sanic(BaseSanic):
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
) )
async def trigger_events(self, events, loop): async def _run_request_middleware(
"""Trigger events (functions or async) self, request, request_name=None
:param events: one or more sync or async functions to execute ): # no cov
:param loop: event loop
"""
for event in events:
result = event(loop)
if isawaitable(result):
await result
async def _run_request_middleware(self, request, request_name=None):
# The if improves speed. I don't know why # The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get( named_middleware = self.named_request_middleware.get(
request_name, deque() request_name, deque()
@ -1118,25 +1172,67 @@ class Sanic(BaseSanic):
request.request_middleware_started = True request.request_middleware_started = True
for middleware in applicable_middleware: for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request) response = middleware(request)
if isawaitable(response): if isawaitable(response):
response = await response response = await response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response: if response:
return response return response
return None return None
async def _run_response_middleware( async def _run_response_middleware(
self, request, response, request_name=None self, request, response, request_name=None
): ): # no cov
named_middleware = self.named_response_middleware.get( named_middleware = self.named_response_middleware.get(
request_name, deque() request_name, deque()
) )
applicable_middleware = self.response_middleware + named_middleware applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware: if applicable_middleware:
for middleware in applicable_middleware: for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response) _response = middleware(request, response)
if isawaitable(_response): if isawaitable(_response):
_response = await _response _response = await _response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
if _response: if _response:
response = _response response = _response
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
@ -1162,10 +1258,6 @@ class Sanic(BaseSanic):
): ):
"""Helper function used by `run` and `create_server`.""" """Helper function used by `run` and `create_server`."""
self.listeners["before_server_start"] = [
self.finalize
] + self.listeners["before_server_start"]
if isinstance(ssl, dict): if isinstance(ssl, dict):
# try common aliaseses # try common aliaseses
cert = ssl.get("cert") or ssl.get("certificate") cert = ssl.get("cert") or ssl.get("certificate")
@ -1202,10 +1294,6 @@ class Sanic(BaseSanic):
# Register start/stop events # Register start/stop events
for event_name, settings_name, reverse in ( for event_name, settings_name, reverse in (
("before_server_start", "before_start", False),
("after_server_start", "after_start", False),
("before_server_stop", "before_stop", True),
("after_server_stop", "after_stop", True),
("main_process_start", "main_start", False), ("main_process_start", "main_start", False),
("main_process_stop", "main_stop", True), ("main_process_stop", "main_stop", True),
): ):
@ -1253,20 +1341,44 @@ class Sanic(BaseSanic):
return ".".join(parts) return ".".join(parts)
@classmethod @classmethod
def _loop_add_task(cls, task, app, loop): def _prep_task(cls, task, app, loop):
if callable(task): if callable(task):
try: try:
loop.create_task(task(app)) task = task(app)
except TypeError: except TypeError:
loop.create_task(task()) task = task()
else:
loop.create_task(task) return task
@classmethod
def _loop_add_task(cls, task, app, loop):
prepped = cls._prep_task(task, app, loop)
loop.create_task(prepped)
@classmethod @classmethod
def _cancel_websocket_tasks(cls, app, loop): def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks: for task in app.websocket_tasks:
task.cancel() task.cancel()
@staticmethod
async def dispatch_delayed_tasks(app, loop):
for name in app._delayed_tasks:
await app.dispatch(name, context={"app": app, "loop": loop})
app._delayed_tasks.clear()
@staticmethod
async def run_delayed_task(app, loop, task):
prepped = app._prep_task(task, app, loop)
await prepped
@staticmethod
async def _listener(
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
):
maybe_coro = listener(app, loop)
if maybe_coro and isawaitable(maybe_coro):
await maybe_coro
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# ASGI # ASGI
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
@ -1340,15 +1452,51 @@ class Sanic(BaseSanic):
raise SanicException(f'Sanic app name "{name}" not found.') raise SanicException(f'Sanic app name "{name}" not found.')
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# Static methods # Lifecycle
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
@staticmethod def finalize(self):
async def finalize(app, _):
try: try:
app.router.finalize() self.router.finalize()
if app.signal_router.routes:
app.signal_router.finalize() # noqa
except FinalizationError as e: except FinalizationError as e:
if not Sanic.test_mode: if not Sanic.test_mode:
raise e # noqa raise e
def signalize(self):
try:
self.signal_router.finalize()
except FinalizationError as e:
if not Sanic.test_mode:
raise e
async def _startup(self):
self.signalize()
self.finalize()
TouchUp.run(self)
async def _server_event(
self,
concern: str,
action: str,
loop: Optional[AbstractEventLoop] = None,
) -> None:
event = f"server.{concern}.{action}"
if action not in ("before", "after") or concern not in (
"init",
"shutdown",
):
raise SanicException(f"Invalid server event: {event}")
logger.debug(f"Triggering server events: {event}")
reverse = concern == "shutdown"
if loop is None:
loop = self.loop
await self.dispatch(
event,
fail_not_found=False,
reverse=reverse,
inline=True,
context={
"app": self,
"loop": loop,
},
)

View File

@ -1,6 +1,5 @@
import warnings import warnings
from inspect import isawaitable
from typing import Optional from typing import Optional
from urllib.parse import quote from urllib.parse import quote
@ -18,14 +17,20 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None: def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app self.asgi_app = asgi_app
if "before_server_start" in self.asgi_app.sanic_app.listeners: if (
"server.init.before"
in self.asgi_app.sanic_app.signal_router.name_index
):
warnings.warn( warnings.warn(
'You have set a listener for "before_server_start" ' 'You have set a listener for "before_server_start" '
"in ASGI mode. " "in ASGI mode. "
"It will be executed as early as possible, but not before " "It will be executed as early as possible, but not before "
"the ASGI server is started." "the ASGI server is started."
) )
if "after_server_stop" in self.asgi_app.sanic_app.listeners: if (
"server.shutdown.after"
in self.asgi_app.sanic_app.signal_router.name_index
):
warnings.warn( warnings.warn(
'You have set a listener for "after_server_stop" ' 'You have set a listener for "after_server_stop" '
"in ASGI mode. " "in ASGI mode. "
@ -42,19 +47,9 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single in sequence since the ASGI lifespan protocol only supports a single
startup event. startup event.
""" """
self.asgi_app.sanic_app.router.finalize() await self.asgi_app.sanic_app._startup()
if self.asgi_app.sanic_app.signal_router.routes: await self.asgi_app.sanic_app._server_event("init", "before")
self.asgi_app.sanic_app.signal_router.finalize() await self.asgi_app.sanic_app._server_event("init", "after")
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_start", []
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if response and isawaitable(response):
await response
async def shutdown(self) -> None: async def shutdown(self) -> None:
""" """
@ -65,16 +60,8 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single in sequence since the ASGI lifespan protocol only supports a single
shutdown event. shutdown event.
""" """
listeners = self.asgi_app.sanic_app.listeners.get( await self.asgi_app.sanic_app._server_event("shutdown", "before")
"before_server_stop", [] await self.asgi_app.sanic_app._server_event("shutdown", "after")
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if response and isawaitable(response):
await response
async def __call__( async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend

View File

@ -21,6 +21,7 @@ from sanic.exceptions import (
from sanic.headers import format_http1_response from sanic.headers import format_http1_response
from sanic.helpers import has_message_body from sanic.helpers import has_message_body
from sanic.log import access_logger, error_logger, logger from sanic.log import access_logger, error_logger, logger
from sanic.touchup import TouchUpMeta
class Stage(Enum): class Stage(Enum):
@ -45,7 +46,7 @@ class Stage(Enum):
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http: class Http(metaclass=TouchUpMeta):
""" """
Internal helper for managing the HTTP request/response cycle Internal helper for managing the HTTP request/response cycle
@ -67,9 +68,15 @@ class Http:
HEADER_CEILING = 16_384 HEADER_CEILING = 16_384
HEADER_MAX_SIZE = 0 HEADER_MAX_SIZE = 0
__touchup__ = (
"http1_request_header",
"http1_response_header",
"read",
)
__slots__ = [ __slots__ = [
"_send", "_send",
"_receive_more", "_receive_more",
"dispatch",
"recv_buffer", "recv_buffer",
"protocol", "protocol",
"expecting_continue", "expecting_continue",
@ -97,6 +104,7 @@ class Http:
self.protocol = protocol self.protocol = protocol
self.keep_alive = True self.keep_alive = True
self.stage: Stage = Stage.IDLE self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch
self.init_for_request() self.init_for_request()
def init_for_request(self): def init_for_request(self):
@ -183,7 +191,7 @@ class Http:
if not self.recv_buffer: if not self.recv_buffer:
await self._receive_more() await self._receive_more()
async def http1_request_header(self): async def http1_request_header(self): # no cov
""" """
Receive and parse request header into self.request. Receive and parse request header into self.request.
""" """
@ -212,6 +220,12 @@ class Http:
reqline, *split_headers = raw_headers.split("\r\n") reqline, *split_headers = raw_headers.split("\r\n")
method, self.url, protocol = reqline.split(" ") method, self.url, protocol = reqline.split(" ")
await self.dispatch(
"http.lifecycle.read_head",
inline=True,
context={"head": bytes(head)},
)
if protocol == "HTTP/1.1": if protocol == "HTTP/1.1":
self.keep_alive = True self.keep_alive = True
elif protocol == "HTTP/1.0": elif protocol == "HTTP/1.0":
@ -250,6 +264,11 @@ class Http:
transport=self.protocol.transport, transport=self.protocol.transport,
app=self.protocol.app, app=self.protocol.app,
) )
await self.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": request},
)
# Prepare for request body # Prepare for request body
self.request_bytes_left = self.request_bytes = 0 self.request_bytes_left = self.request_bytes = 0
@ -280,7 +299,7 @@ class Http:
async def http1_response_header( async def http1_response_header(
self, data: bytes, end_stream: bool self, data: bytes, end_stream: bool
) -> None: ) -> None: # no cov
res = self.response res = self.response
# Compatibility with simple response body # Compatibility with simple response body
@ -469,7 +488,7 @@ class Http:
if data: if data:
yield data yield data
async def read(self) -> Optional[bytes]: async def read(self) -> Optional[bytes]: # no cov
""" """
Read some bytes of request body. Read some bytes of request body.
""" """
@ -543,6 +562,12 @@ class Http:
self.request_bytes_left -= size self.request_bytes_left -= size
await self.dispatch(
"http.lifecycle.read_body",
inline=True,
context={"body": data},
)
return data return data
# Response methods # Response methods

View File

@ -9,10 +9,10 @@ class ListenerEvent(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower() return name.lower()
BEFORE_SERVER_START = auto() BEFORE_SERVER_START = "server.init.before"
AFTER_SERVER_START = auto() AFTER_SERVER_START = "server.init.after"
BEFORE_SERVER_STOP = auto() BEFORE_SERVER_STOP = "server.shutdown.before"
AFTER_SERVER_STOP = auto() AFTER_SERVER_STOP = "server.shutdown.after"
MAIN_PROCESS_START = auto() MAIN_PROCESS_START = auto()
MAIN_PROCESS_STOP = auto() MAIN_PROCESS_STOP = auto()

View File

@ -23,7 +23,7 @@ class SignalMixin:
*, *,
apply: bool = True, apply: bool = True,
condition: Dict[str, Any] = None, condition: Dict[str, Any] = None,
) -> Callable[[SignalHandler], FutureSignal]: ) -> Callable[[SignalHandler], SignalHandler]:
""" """
For creating a signal handler, used similar to a route handler: For creating a signal handler, used similar to a route handler:
@ -54,7 +54,7 @@ class SignalMixin:
if apply: if apply:
self._apply_signal(future_signal) self._apply_signal(future_signal)
return future_signal return handler
return decorator return decorator

View File

@ -497,6 +497,10 @@ class Request:
""" """
return self._match_info return self._match_info
@match_info.setter
def match_info(self, value):
self._match_info = value
# Transport properties (obtained from local interface only) # Transport properties (obtained from local interface only)
@property @property

View File

@ -13,7 +13,7 @@ from typing import (
Union, Union,
) )
from sanic.models.handler_types import ListenerType from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING: if TYPE_CHECKING:
@ -37,7 +37,7 @@ from time import monotonic as current_time
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.config import Config from sanic.config import Config
from sanic.exceptions import RequestTimeout, ServiceUnavailable from sanic.exceptions import RequestTimeout, SanicException, ServiceUnavailable
from sanic.http import Http, Stage from sanic.http import Http, Stage
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
@ -102,11 +102,15 @@ class ConnInfo:
self.client_port = addr[1] self.client_port = addr[1]
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
""" """
This class provides a basic HTTP implementation of the sanic framework. This class provides a basic HTTP implementation of the sanic framework.
""" """
__touchup__ = (
"send",
"connection_task",
)
__slots__ = ( __slots__ = (
# app # app
"app", "app",
@ -185,7 +189,7 @@ class HttpProtocol(asyncio.Protocol):
self._time = current_time() self._time = current_time()
self.check_timeouts() self.check_timeouts()
async def connection_task(self): async def connection_task(self): # no cov
""" """
Run a HTTP connection. Run a HTTP connection.
@ -194,6 +198,11 @@ class HttpProtocol(asyncio.Protocol):
""" """
try: try:
self._setup_connection() self._setup_connection()
await self.app.dispatch(
"http.lifecycle.begin",
inline=True,
context={"conn_info": self.conn_info},
)
await self._http.http1() await self._http.http1()
except CancelledError: except CancelledError:
pass pass
@ -212,6 +221,13 @@ class HttpProtocol(asyncio.Protocol):
self.close() self.close()
except BaseException: except BaseException:
error_logger.exception("Closing failed") error_logger.exception("Closing failed")
finally:
await self.app.dispatch(
"http.lifecycle.complete",
inline=True,
context={"conn_info": self.conn_info},
)
...
async def receive_more(self): async def receive_more(self):
""" """
@ -259,13 +275,18 @@ class HttpProtocol(asyncio.Protocol):
except Exception: except Exception:
error_logger.exception("protocol.check_timeouts") error_logger.exception("protocol.check_timeouts")
async def send(self, data): async def send(self, data): # no cov
""" """
Writes data with backpressure control. Writes data with backpressure control.
""" """
await self._can_write.wait() await self._can_write.wait()
if self.transport.is_closing(): if self.transport.is_closing():
raise CancelledError raise CancelledError
await self.app.dispatch(
"http.lifecycle.send",
inline=True,
context={"data": data},
)
self.transport.write(data) self.transport.write(data)
self._time = current_time() self._time = current_time()
@ -359,52 +380,54 @@ class AsyncioServer:
a user who needs to manage the server lifecycle manually. a user who needs to manage the server lifecycle manually.
""" """
__slots__ = ( __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init")
"loop",
"serve_coro",
"_after_start",
"_before_stop",
"_after_stop",
"server",
"connections",
)
def __init__( def __init__(
self, self,
app,
loop, loop,
serve_coro, serve_coro,
connections, connections,
after_start: Optional[Iterable[ListenerType]],
before_stop: Optional[Iterable[ListenerType]],
after_stop: Optional[Iterable[ListenerType]],
): ):
# Note, Sanic already called "before_server_start" events # Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here. # before this helper was even created. So we don't need it here.
self.app = app
self.connections = connections
self.loop = loop self.loop = loop
self.serve_coro = serve_coro self.serve_coro = serve_coro
self._after_start = after_start
self._before_stop = before_stop
self._after_stop = after_stop
self.server = None self.server = None
self.connections = connections self.init = False
def startup(self):
"""
Trigger "before_server_start" events
"""
self.init = True
return self.app._startup()
def before_start(self):
"""
Trigger "before_server_start" events
"""
return self._server_event("init", "before")
def after_start(self): def after_start(self):
""" """
Trigger "after_server_start" events Trigger "after_server_start" events
""" """
trigger_events(self._after_start, self.loop) return self._server_event("init", "after")
def before_stop(self): def before_stop(self):
""" """
Trigger "before_server_stop" events Trigger "before_server_stop" events
""" """
trigger_events(self._before_stop, self.loop) return self._server_event("shutdown", "before")
def after_stop(self): def after_stop(self):
""" """
Trigger "after_server_stop" events Trigger "after_server_stop" events
""" """
trigger_events(self._after_stop, self.loop) return self._server_event("shutdown", "after")
def is_serving(self) -> bool: def is_serving(self) -> bool:
if self.server: if self.server:
@ -442,6 +465,14 @@ class AsyncioServer:
"of asyncio or uvloop." "of asyncio or uvloop."
) )
def _server_event(self, concern: str, action: str):
if not self.init:
raise SanicException(
"Cannot dispatch server event without "
"first running server.startup()"
)
return self.app._server_event(concern, action, loop=self.loop)
def __await__(self): def __await__(self):
""" """
Starts the asyncio server, returns AsyncServerCoro Starts the asyncio server, returns AsyncServerCoro
@ -456,11 +487,7 @@ class AsyncioServer:
def serve( def serve(
host, host,
port, port,
app, app: Sanic,
before_start: Optional[Iterable[ListenerType]] = None,
after_start: Optional[Iterable[ListenerType]] = None,
before_stop: Optional[Iterable[ListenerType]] = None,
after_stop: Optional[Iterable[ListenerType]] = None,
ssl: Optional[SSLContext] = None, ssl: Optional[SSLContext] = None,
sock: Optional[socket.socket] = None, sock: Optional[socket.socket] = None,
unix: Optional[str] = None, unix: Optional[str] = None,
@ -542,15 +569,14 @@ def serve(
if run_async: if run_async:
return AsyncioServer( return AsyncioServer(
app=app,
loop=loop, loop=loop,
serve_coro=server_coroutine, serve_coro=server_coroutine,
connections=connections, connections=connections,
after_start=after_start,
before_stop=before_stop,
after_stop=after_stop,
) )
trigger_events(before_start, loop) loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before"))
try: try:
http_server = loop.run_until_complete(server_coroutine) http_server = loop.run_until_complete(server_coroutine)
@ -558,8 +584,6 @@ def serve(
error_logger.exception("Unable to start server") error_logger.exception("Unable to start server")
return return
trigger_events(after_start, loop)
# Ignore SIGINT when run_multiple # Ignore SIGINT when run_multiple
if run_multiple: if run_multiple:
signal_func(SIGINT, SIG_IGN) signal_func(SIGINT, SIG_IGN)
@ -571,6 +595,8 @@ def serve(
else: else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop) loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid() pid = os.getpid()
try: try:
logger.info("Starting worker [%s]", pid) logger.info("Starting worker [%s]", pid)
@ -579,7 +605,7 @@ def serve(
logger.info("Stopping worker [%s]", pid) logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided # Run the on_stop function if provided
trigger_events(before_stop, loop) loop.run_until_complete(app._server_event("shutdown", "before"))
# Wait for event loop to finish and all connections to drain # Wait for event loop to finish and all connections to drain
http_server.close() http_server.close()
@ -611,8 +637,7 @@ def serve(
_shutdown = asyncio.gather(*coros) _shutdown = asyncio.gather(*coros)
loop.run_until_complete(_shutdown) loop.run_until_complete(_shutdown)
loop.run_until_complete(app._server_event("shutdown", "after"))
trigger_events(after_stop, loop)
remove_unix_socket(unix) remove_unix_socket(unix)

View File

@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound # type: ignore
from sanic_routing.utils import path_to_parts # type: ignore from sanic_routing.utils import path_to_parts # type: ignore
from sanic.exceptions import InvalidSignal from sanic.exceptions import InvalidSignal
from sanic.log import error_logger, logger
from sanic.models.handler_types import SignalHandler from sanic.models.handler_types import SignalHandler
RESERVED_NAMESPACES = ( RESERVED_NAMESPACES = {
"server", "server": (
"http", # "server.main.start",
) # "server.main.stop",
"server.init.before",
"server.init.after",
"server.shutdown.before",
"server.shutdown.after",
),
"http": (
"http.lifecycle.begin",
"http.lifecycle.complete",
"http.lifecycle.exception",
"http.lifecycle.handle",
"http.lifecycle.read_body",
"http.lifecycle.read_head",
"http.lifecycle.request",
"http.lifecycle.response",
"http.routing.after",
"http.routing.before",
"http.lifecycle.send",
"http.middleware.after",
"http.middleware.before",
),
}
def _blank():
...
class Signal(Route): class Signal(Route):
@ -59,8 +85,13 @@ class SignalRouter(BaseRouter):
terms.append(extra) terms.append(extra)
raise NotFound(message % tuple(terms)) raise NotFound(message % tuple(terms))
# Regex routes evaluate and can extract params directly. They are set
# on param_basket["__params__"]
params = param_basket["__params__"] params = param_basket["__params__"]
if not params: if not params:
# If param_basket["__params__"] does not exist, we might have
# param_basket["__matches__"], which are indexed based matches
# on path segments. They should already be cast types.
params = { params = {
param.name: param_basket["__matches__"][idx] param.name: param_basket["__matches__"][idx]
for idx, param in group.params.items() for idx, param in group.params.items()
@ -73,8 +104,18 @@ class SignalRouter(BaseRouter):
event: str, event: str,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None, condition: Optional[Dict[str, str]] = None,
) -> None: fail_not_found: bool = True,
group, handlers, params = self.get(event, condition=condition) reverse: bool = False,
) -> Any:
try:
group, handlers, params = self.get(event, condition=condition)
except NotFound as e:
if fail_not_found:
raise e
else:
if self.ctx.app.debug:
error_logger.warning(str(e))
return None
events = [signal.ctx.event for signal in group] events = [signal.ctx.event for signal in group]
for signal_event in events: for signal_event in events:
@ -82,12 +123,19 @@ class SignalRouter(BaseRouter):
if context: if context:
params.update(context) params.update(context)
if not reverse:
handlers = handlers[::-1]
try: try:
for handler in handlers: for handler in handlers:
if condition is None or condition == handler.__requirements__: if condition is None or condition == handler.__requirements__:
maybe_coroutine = handler(**params) maybe_coroutine = handler(**params)
if isawaitable(maybe_coroutine): if isawaitable(maybe_coroutine):
await maybe_coroutine retval = await maybe_coroutine
if retval:
return retval
elif maybe_coroutine:
return maybe_coroutine
return None
finally: finally:
for signal_event in events: for signal_event in events:
signal_event.clear() signal_event.clear()
@ -98,14 +146,23 @@ class SignalRouter(BaseRouter):
*, *,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None, condition: Optional[Dict[str, str]] = None,
) -> asyncio.Task: fail_not_found: bool = True,
task = self.ctx.loop.create_task( inline: bool = False,
self._dispatch( reverse: bool = False,
event, ) -> Union[asyncio.Task, Any]:
context=context, dispatch = self._dispatch(
condition=condition, event,
) context=context,
condition=condition,
fail_not_found=fail_not_found and inline,
reverse=reverse,
) )
logger.debug(f"Dispatching signal: {event}")
if inline:
return await dispatch
task = asyncio.get_running_loop().create_task(dispatch)
await asyncio.sleep(0) await asyncio.sleep(0)
return task return task
@ -131,7 +188,9 @@ class SignalRouter(BaseRouter):
append=True, append=True,
) # type: ignore ) # type: ignore
def finalize(self, do_compile: bool = True): def finalize(self, do_compile: bool = True, do_optimize: bool = False):
self.add(_blank, "sanic.__signal__.__init__")
try: try:
self.ctx.loop = asyncio.get_running_loop() self.ctx.loop = asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
@ -140,7 +199,7 @@ class SignalRouter(BaseRouter):
for signal in self.routes: for signal in self.routes:
signal.ctx.event = asyncio.Event() signal.ctx.event = asyncio.Event()
return super().finalize(do_compile=do_compile) return super().finalize(do_compile=do_compile, do_optimize=do_optimize)
def _build_event_parts(self, event: str) -> Tuple[str, str, str]: def _build_event_parts(self, event: str) -> Tuple[str, str, str]:
parts = path_to_parts(event, self.delimiter) parts = path_to_parts(event, self.delimiter)
@ -151,7 +210,11 @@ class SignalRouter(BaseRouter):
): ):
raise InvalidSignal("Invalid signal event: %s" % event) raise InvalidSignal("Invalid signal event: %s" % event)
if parts[0] in RESERVED_NAMESPACES: if (
parts[0] in RESERVED_NAMESPACES
and event not in RESERVED_NAMESPACES[parts[0]]
and not (parts[2].startswith("<") and parts[2].endswith(">"))
):
raise InvalidSignal( raise InvalidSignal(
"Cannot declare reserved signal event: %s" % event "Cannot declare reserved signal event: %s" % event
) )

View File

@ -0,0 +1,8 @@
from .meta import TouchUpMeta
from .service import TouchUp
__all__ = (
"TouchUp",
"TouchUpMeta",
)

22
sanic/touchup/meta.py Normal file
View File

@ -0,0 +1,22 @@
from sanic.exceptions import SanicException
from .service import TouchUp
class TouchUpMeta(type):
def __new__(cls, name, bases, attrs, **kwargs):
gen_class = super().__new__(cls, name, bases, attrs, **kwargs)
methods = attrs.get("__touchup__")
attrs["__touched__"] = False
if methods:
for method in methods:
if method not in attrs:
raise SanicException(
"Cannot perform touchup on non-existent method: "
f"{name}.{method}"
)
TouchUp.register(gen_class, method)
return gen_class

View File

@ -0,0 +1,5 @@
from .base import BaseScheme
from .ode import OptionalDispatchEvent # noqa
__all__ = ("BaseScheme",)

View File

@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from typing import Set, Type
class BaseScheme(ABC):
ident: str
_registry: Set[Type] = set()
def __init__(self, app) -> None:
self.app = app
@abstractmethod
def run(self, method, module_globals) -> None:
...
def __init_subclass__(cls):
BaseScheme._registry.add(cls)
def __call__(self, method, module_globals):
return self.run(method, module_globals)

View File

@ -0,0 +1,67 @@
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any
from sanic.log import logger
from .base import BaseScheme
class OptionalDispatchEvent(BaseScheme):
ident = "ODE"
def __init__(self, app) -> None:
super().__init__(app)
self._registered_events = [
signal.path for signal in app.signal_router.routes
]
def run(self, method, module_globals):
raw_source = getsource(method)
src = dedent(raw_source)
tree = parse(src)
node = RemoveDispatch(self._registered_events).visit(tree)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]
class RemoveDispatch(NodeTransformer):
def __init__(self, registered_events) -> None:
self._registered_events = registered_events
def visit_Expr(self, node: Expr) -> Any:
call = node.value
if isinstance(call, Await):
call = call.value
func = getattr(call, "func", None)
args = getattr(call, "args", None)
if not func or not args:
return node
if isinstance(func, Attribute) and func.attr == "dispatch":
event = args[0]
if hasattr(event, "s"):
event_name = getattr(event, "value", event.s)
if self._not_registered(event_name):
logger.debug(f"Disabling event: {event_name}")
return None
return node
def _not_registered(self, event_name):
dynamic = []
for event in self._registered_events:
if event.endswith(">"):
namespace_concern, _ = event.rsplit(".", 1)
dynamic.append(namespace_concern)
namespace_concern, _ = event_name.rsplit(".", 1)
return (
event_name not in self._registered_events
and namespace_concern not in dynamic
)

33
sanic/touchup/service.py Normal file
View File

@ -0,0 +1,33 @@
from inspect import getmembers, getmodule
from typing import Set, Tuple, Type
from .schemes import BaseScheme
class TouchUp:
_registry: Set[Tuple[Type, str]] = set()
@classmethod
def run(cls, app):
for target, method_name in cls._registry:
method = getattr(target, method_name)
if app.test_mode:
placeholder = f"_{method_name}"
if hasattr(target, placeholder):
method = getattr(target, placeholder)
else:
setattr(target, placeholder, method)
module = getmodule(target)
module_globals = dict(getmembers(module))
for scheme in BaseScheme._registry:
modified = scheme(app)(method, module_globals)
setattr(target, method_name, modified)
target.__touched__ = True
@classmethod
def register(cls, target, method_name):
cls._registry.add((target, method_name))

View File

@ -8,7 +8,7 @@ import traceback
from gunicorn.workers import base # type: ignore from gunicorn.workers import base # type: ignore
from sanic.log import logger from sanic.log import logger
from sanic.server import HttpProtocol, Signal, serve, trigger_events from sanic.server import HttpProtocol, Signal, serve
from sanic.websocket import WebSocketProtocol from sanic.websocket import WebSocketProtocol
@ -68,10 +68,10 @@ class GunicornWorker(base.Worker):
) )
self._server_settings["signal"] = self.signal self._server_settings["signal"] = self.signal
self._server_settings.pop("sock") self._server_settings.pop("sock")
trigger_events( self._await(self.app.callable._startup())
self._server_settings.get("before_start", []), self.loop self._await(
self.app.callable._server_event("init", "before", loop=self.loop)
) )
self._server_settings["before_start"] = ()
main_start = self._server_settings.pop("main_start", None) main_start = self._server_settings.pop("main_start", None)
main_stop = self._server_settings.pop("main_stop", None) main_stop = self._server_settings.pop("main_stop", None)
@ -82,24 +82,29 @@ class GunicornWorker(base.Worker):
"with GunicornWorker" "with GunicornWorker"
) )
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
try: try:
self.loop.run_until_complete(self._runner) self._await(self._run())
self.app.callable.is_running = True self.app.callable.is_running = True
trigger_events( self._await(
self._server_settings.get("after_start", []), self.loop self.app.callable._server_event(
"init", "after", loop=self.loop
)
) )
self.loop.run_until_complete(self._check_alive()) self.loop.run_until_complete(self._check_alive())
trigger_events( self._await(
self._server_settings.get("before_stop", []), self.loop self.app.callable._server_event(
"shutdown", "before", loop=self.loop
)
) )
self.loop.run_until_complete(self.close()) self.loop.run_until_complete(self.close())
except BaseException: except BaseException:
traceback.print_exc() traceback.print_exc()
finally: finally:
try: try:
trigger_events( self._await(
self._server_settings.get("after_stop", []), self.loop self.app.callable._server_event(
"shutdown", "after", loop=self.loop
)
) )
except BaseException: except BaseException:
traceback.print_exc() traceback.print_exc()
@ -238,3 +243,7 @@ class GunicornWorker(base.Worker):
self.exit_code = 1 self.exit_code = 1
self.cfg.worker_abort(self) self.cfg.worker_abort(self)
sys.exit(1) sys.exit(1)
def _await(self, coro):
fut = asyncio.ensure_future(coro, loop=self.loop)
self.loop.run_until_complete(fut)

View File

@ -93,7 +93,7 @@ requirements = [
] ]
tests_require = [ tests_require = [
"sanic-testing==0.7.0b1", "sanic-testing>=0.7.0b2",
"pytest==5.2.1", "pytest==5.2.1",
"coverage==5.3", "coverage==5.3",
"gunicorn==20.0.4", "gunicorn==20.0.4",

View File

@ -1,3 +1,5 @@
import asyncio
import logging
import random import random
import re import re
import string import string
@ -9,10 +11,12 @@ from typing import Tuple
import pytest import pytest
from sanic_routing.exceptions import RouteExists from sanic_routing.exceptions import RouteExists
from sanic_testing.testing import PORT
from sanic import Sanic from sanic import Sanic
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.router import Router from sanic.router import Router
from sanic.touchup.service import TouchUp
slugify = re.compile(r"[^a-zA-Z0-9_\-]") slugify = re.compile(r"[^a-zA-Z0-9_\-]")
@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]:
collect_ignore = ["test_worker.py"] collect_ignore = ["test_worker.py"]
@pytest.fixture
def caplog(caplog):
yield caplog
async def _handler(request): async def _handler(request):
""" """
Dummy placeholder method used for route resolver when creating a new Dummy placeholder method used for route resolver when creating a new
@ -41,33 +40,32 @@ async def _handler(request):
TYPE_TO_GENERATOR_MAP = { TYPE_TO_GENERATOR_MAP = {
"string": lambda: "".join( "str": lambda: "".join(
[random.choice(string.ascii_lowercase) for _ in range(4)] [random.choice(string.ascii_lowercase) for _ in range(4)]
), ),
"int": lambda: random.choice(range(1000000)), "int": lambda: random.choice(range(1000000)),
"number": lambda: random.random(), "float": lambda: random.random(),
"alpha": lambda: "".join( "alpha": lambda: "".join(
[random.choice(string.ascii_lowercase) for _ in range(4)] [random.choice(string.ascii_lowercase) for _ in range(4)]
), ),
"uuid": lambda: str(uuid.uuid1()), "uuid": lambda: str(uuid.uuid1()),
} }
CACHE = {}
class RouteStringGenerator: class RouteStringGenerator:
ROUTE_COUNT_PER_DEPTH = 100 ROUTE_COUNT_PER_DEPTH = 100
HTTP_METHODS = HTTP_METHODS HTTP_METHODS = HTTP_METHODS
ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"] ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"]
def generate_random_direct_route(self, max_route_depth=4): def generate_random_direct_route(self, max_route_depth=4):
routes = [] routes = []
for depth in range(1, max_route_depth + 1): for depth in range(1, max_route_depth + 1):
for _ in range(self.ROUTE_COUNT_PER_DEPTH): for _ in range(self.ROUTE_COUNT_PER_DEPTH):
route = "/".join( route = "/".join(
[ [TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)]
TYPE_TO_GENERATOR_MAP.get("string")()
for _ in range(depth)
]
) )
route = route.replace(".", "", -1) route = route.replace(".", "", -1)
route_detail = (random.choice(self.HTTP_METHODS), route) route_detail = (random.choice(self.HTTP_METHODS), route)
@ -83,7 +81,7 @@ class RouteStringGenerator:
new_route_part = "/".join( new_route_part = "/".join(
[ [
"<{}:{}>".format( "<{}:{}>".format(
TYPE_TO_GENERATOR_MAP.get("string")(), TYPE_TO_GENERATOR_MAP.get("str")(),
random.choice(self.ROUTE_PARAM_TYPES), random.choice(self.ROUTE_PARAM_TYPES),
) )
for _ in range(max_route_depth - current_length) for _ in range(max_route_depth - current_length)
@ -98,7 +96,7 @@ class RouteStringGenerator:
def generate_url_for_template(template): def generate_url_for_template(template):
url = template url = template
for pattern, param_type in re.findall( for pattern, param_type in re.findall(
re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"), re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"),
template, template,
): ):
value = TYPE_TO_GENERATOR_MAP.get(param_type)() value = TYPE_TO_GENERATOR_MAP.get(param_type)()
@ -141,5 +139,33 @@ def url_param_generator():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def app(request): def app(request):
if not CACHE:
for target, method_name in TouchUp._registry:
CACHE[method_name] = getattr(target, method_name)
app = Sanic(slugify.sub("-", request.node.name)) app = Sanic(slugify.sub("-", request.node.name))
return app yield app
for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name])
@pytest.fixture(scope="function")
def run_startup(caplog):
def run(app):
nonlocal caplog
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
with caplog.at_level(logging.DEBUG):
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop._stopping = False
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
return caplog.record_tuples
return run

View File

@ -89,7 +89,7 @@ def test_debug(cmd):
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
app_info = lines[9] app_info = lines[26]
info = json.loads(app_info) info = json.loads(app_info)
assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO
@ -103,7 +103,7 @@ def test_auto_reload(cmd):
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
app_info = lines[9] app_info = lines[26]
info = json.loads(app_info) info = json.loads(app_info)
assert info["debug"] is False assert info["debug"] is False
@ -118,7 +118,7 @@ def test_access_logs(cmd, expected):
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
app_info = lines[9] app_info = lines[26]
info = json.loads(app_info) info = json.loads(app_info)
assert info["access_log"] is expected assert info["access_log"] is expected

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
from queue import Queue
from threading import Event from threading import Event
from sanic.response import text from sanic.response import text
@ -13,8 +12,6 @@ def test_create_task(app):
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
e.set() e.set()
app.add_task(coro)
@app.route("/early") @app.route("/early")
def not_set(request): def not_set(request):
return text(str(e.is_set())) return text(str(e.is_set()))
@ -24,24 +21,30 @@ def test_create_task(app):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
return text(str(e.is_set())) return text(str(e.is_set()))
app.add_task(coro)
request, response = app.test_client.get("/early") request, response = app.test_client.get("/early")
assert response.body == b"False" assert response.body == b"False"
app.signal_router.reset()
app.add_task(coro)
request, response = app.test_client.get("/late") request, response = app.test_client.get("/late")
assert response.body == b"True" assert response.body == b"True"
def test_create_task_with_app_arg(app): def test_create_task_with_app_arg(app):
q = Queue() @app.after_server_start
async def setup_q(app, _):
app.ctx.q = asyncio.Queue()
@app.route("/") @app.route("/")
def not_set(request): async def not_set(request):
return "hello" return text(await request.app.ctx.q.get())
async def coro(app): async def coro(app):
q.put(app.name) await app.ctx.q.put(app.name)
app.add_task(coro) app.add_task(coro)
request, response = app.test_client.get("/") _, response = app.test_client.get("/")
assert q.get() == "test_create_task_with_app_arg" assert response.text == "test_create_task_with_app_arg"

View File

@ -127,7 +127,6 @@ def test_html_traceback_output_in_debug_mode():
soup = BeautifulSoup(response.body, "html.parser") soup = BeautifulSoup(response.body, "html.parser")
html = str(soup) html = str(soup)
assert "response = handler(request, **kwargs)" in html
assert "handler_4" in html assert "handler_4" in html
assert "foo = bar" in html assert "foo = bar" in html
@ -151,7 +150,6 @@ def test_chained_exception_handler():
soup = BeautifulSoup(response.body, "html.parser") soup = BeautifulSoup(response.body, "html.parser")
html = str(soup) html = str(soup)
assert "response = handler(request, **kwargs)" in html
assert "handler_6" in html assert "handler_6" in html
assert "foo = 1 / arg" in html assert "foo = 1 / arg" in html
assert "ValueError" in html assert "ValueError" in html

View File

@ -2,16 +2,13 @@ import asyncio
import platform import platform
from asyncio import sleep as aio_sleep from asyncio import sleep as aio_sleep
from json import JSONDecodeError
from os import environ from os import environ
import httpcore
import httpx
import pytest import pytest
from sanic_testing.testing import HOST, SanicTestClient from sanic_testing.reusable import ReusableClient
from sanic import Sanic, server from sanic import Sanic
from sanic.compat import OS_IS_WINDOWS from sanic.compat import OS_IS_WINDOWS
from sanic.response import text from sanic.response import text
@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool):
last_reused_connection = None
async def _get_connection_from_pool(self, *args, **kwargs):
conn = await super()._get_connection_from_pool(*args, **kwargs)
self.__class__.last_reused_connection = conn
return conn
class ResusableSanicSession(httpx.AsyncClient):
def __init__(self, *args, **kwargs) -> None:
transport = ReusableSanicConnectionPool()
super().__init__(transport=transport, *args, **kwargs)
class ReuseableSanicTestClient(SanicTestClient):
def __init__(self, app, loop=None):
super().__init__(app)
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._server = None
self._tcp_connector = None
self._session = None
def get_new_session(self):
return ResusableSanicSession()
# Copied from SanicTestClient, but with some changes to reuse the
# same loop for the same app.
def _sanic_endpoint_test(
self,
method="get",
uri="/",
gather_request=True,
debug=False,
server_kwargs=None,
*request_args,
**request_kwargs,
):
loop = self._loop
results = [None, None]
exceptions = []
server_kwargs = server_kwargs or {"return_asyncio_server": True}
if gather_request:
def _collect_request(request):
if results[0] is None:
results[0] = request
self.app.request_middleware.appendleft(_collect_request)
if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
url = uri
else:
uri = uri if uri.startswith("/") else f"/{uri}"
scheme = "http"
url = f"{scheme}://{HOST}:{PORT}{uri}"
@self.app.listener("after_server_start")
async def _collect_response(loop):
try:
response = await self._local_request(
method, url, *request_args, **request_kwargs
)
results[-1] = response
except Exception as e2:
exceptions.append(e2)
if self._server is not None:
_server = self._server
else:
_server_co = self.app.create_server(
host=HOST, debug=debug, port=PORT, **server_kwargs
)
server.trigger_events(
self.app.listeners["before_server_start"], loop
)
try:
loop._stopping = False
_server = loop.run_until_complete(_server_co)
except Exception as e1:
raise e1
self._server = _server
server.trigger_events(self.app.listeners["after_server_start"], loop)
self.app.listeners["after_server_start"].pop()
if exceptions:
raise ValueError(f"Exception during request: {exceptions}")
if gather_request:
self.app.request_middleware.pop()
try:
request, response = results
return request, response
except Exception:
raise ValueError(
f"Request and response object expected, got ({results})"
)
else:
try:
return results[-1]
except Exception:
raise ValueError(f"Request object expected, got ({results})")
def kill_server(self):
try:
if self._server:
self._server.close()
self._loop.run_until_complete(self._server.wait_closed())
self._server = None
if self._session:
self._loop.run_until_complete(self._session.aclose())
self._session = None
except Exception as e3:
raise e3
# Copied from SanicTestClient, but with some changes to reuse the
# same TCPConnection and the sane ClientSession more than once.
# Note, you cannot use the same session if you are in a _different_
# loop, so the changes above are required too.
async def _local_request(self, method, url, *args, **kwargs):
raw_cookies = kwargs.pop("raw_cookies", None)
request_keepalive = kwargs.pop(
"request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"]
)
if not self._session:
self._session = self.get_new_session()
try:
response = await getattr(self._session, method.lower())(
url, timeout=request_keepalive, *args, **kwargs
)
except NameError:
raise Exception(response.status_code)
try:
response.json = response.json()
except (JSONDecodeError, UnicodeDecodeError):
response.json = None
response.body = await response.aread()
response.status = response.status_code
response.content_type = response.headers.get("content-type")
if raw_cookies:
response.raw_cookies = {}
for cookie in response.cookies:
response.raw_cookies[cookie.name] = cookie
return response
keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse")
keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") keep_alive_app_client_timeout = Sanic("test_ka_client_timeout")
keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") keep_alive_app_server_timeout = Sanic("test_ka_server_timeout")
@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are """If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully both longer than the delay, the client _and_ server will successfully
reuse the existing connection.""" reuse the existing connection."""
try: loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop() asyncio.set_event_loop(loop)
asyncio.set_event_loop(loop) client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT)
client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop) with client:
headers = {"Connection": "keep-alive"} headers = {"Connection": "keep-alive"}
request, response = client.get("/1", headers=headers) request, response = client.get("/1", headers=headers)
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(1)) loop.run_until_complete(aio_sleep(1))
request, response = client.get("/1") request, response = client.get("/1")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
assert ReusableSanicConnectionPool.last_reused_connection assert request.protocol.state["requests_count"] == 2
finally:
client.kill_server()
@pytest.mark.skipif( @pytest.mark.skipif(
@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse():
def test_keep_alive_client_timeout(): def test_keep_alive_client_timeout():
"""If the server keep-alive timeout is longer than the client """If the server keep-alive timeout is longer than the client
keep-alive timeout, client will try to create a new connection here.""" keep-alive timeout, client will try to create a new connection here."""
try: loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop() asyncio.set_event_loop(loop)
asyncio.set_event_loop(loop) client = ReusableClient(
client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) keep_alive_app_client_timeout, loop=loop, port=PORT
)
with client:
headers = {"Connection": "keep-alive"} headers = {"Connection": "keep-alive"}
_, response = client.get("/1", headers=headers, request_keepalive=1) request, response = client.get("/1", headers=headers, timeout=1)
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(2)) loop.run_until_complete(aio_sleep(2))
_, response = client.get("/1", request_keepalive=1) request, response = client.get("/1", timeout=1)
assert request.protocol.state["requests_count"] == 1
assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
@pytest.mark.skipif( @pytest.mark.skipif(
@ -277,22 +117,23 @@ def test_keep_alive_server_timeout():
keep-alive timeout, the client will either a 'Connection reset' error keep-alive timeout, the client will either a 'Connection reset' error
_or_ a new connection. Depending on how the event-loop handles the _or_ a new connection. Depending on how the event-loop handles the
broken server connection.""" broken server connection."""
try: loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop() asyncio.set_event_loop(loop)
asyncio.set_event_loop(loop) client = ReusableClient(
client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) keep_alive_app_server_timeout, loop=loop, port=PORT
)
with client:
headers = {"Connection": "keep-alive"} headers = {"Connection": "keep-alive"}
_, response = client.get("/1", headers=headers, request_keepalive=60) request, response = client.get("/1", headers=headers, timeout=60)
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(3)) loop.run_until_complete(aio_sleep(3))
_, response = client.get("/1", request_keepalive=60) request, response = client.get("/1", timeout=60)
assert ReusableSanicConnectionPool.last_reused_connection is None assert request.protocol.state["requests_count"] == 1
finally:
client.kill_server()
@pytest.mark.skipif( @pytest.mark.skipif(
@ -300,10 +141,10 @@ def test_keep_alive_server_timeout():
reason="Not testable with current client", reason="Not testable with current client",
) )
def test_keep_alive_connection_context(): def test_keep_alive_connection_context():
try: loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop() asyncio.set_event_loop(loop)
asyncio.set_event_loop(loop) client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT)
client = ReuseableSanicTestClient(keep_alive_app_context, loop) with client:
headers = {"Connection": "keep-alive"} headers = {"Connection": "keep-alive"}
request1, _ = client.post("/ctx", headers=headers) request1, _ = client.post("/ctx", headers=headers)
@ -315,5 +156,4 @@ def test_keep_alive_connection_context():
assert ( assert (
request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello"
) )
finally: assert request2.protocol.state["requests_count"] == 2
client.kill_server()

View File

@ -6,85 +6,37 @@ from sanic_testing.testing import PORT
from sanic.config import BASE_LOGO from sanic.config import BASE_LOGO
def test_logo_base(app, caplog): def test_logo_base(app, run_startup):
server = app.create_server( logs = run_startup(app)
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
with caplog.at_level(logging.DEBUG): assert logs[0][1] == logging.DEBUG
_server = loop.run_until_complete(server) assert logs[0][2] == BASE_LOGO
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == BASE_LOGO
def test_logo_false(app, caplog): def test_logo_false(app, caplog, run_startup):
app.config.LOGO = False app.config.LOGO = False
server = app.create_server( logs = run_startup(app)
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
with caplog.at_level(logging.DEBUG): banner, port = logs[0][2].rsplit(":", 1)
_server = loop.run_until_complete(server) assert logs[0][1] == logging.INFO
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
banner, port = caplog.record_tuples[0][2].rsplit(":", 1)
assert caplog.record_tuples[0][1] == logging.INFO
assert banner == "Goin' Fast @ http://127.0.0.1" assert banner == "Goin' Fast @ http://127.0.0.1"
assert int(port) > 0 assert int(port) > 0
def test_logo_true(app, caplog): def test_logo_true(app, run_startup):
app.config.LOGO = True app.config.LOGO = True
server = app.create_server( logs = run_startup(app)
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
with caplog.at_level(logging.DEBUG): assert logs[0][1] == logging.DEBUG
_server = loop.run_until_complete(server) assert logs[0][2] == BASE_LOGO
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == BASE_LOGO
def test_logo_custom(app, caplog): def test_logo_custom(app, run_startup):
app.config.LOGO = "My Custom Logo" app.config.LOGO = "My Custom Logo"
server = app.create_server( logs = run_startup(app)
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
with caplog.at_level(logging.DEBUG): assert logs[0][1] == logging.DEBUG
_server = loop.run_until_complete(server) assert logs[0][2] == "My Custom Logo"
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == "My Custom Logo"

View File

@ -2,6 +2,7 @@ import asyncio
import httpcore import httpcore
import httpx import httpx
import pytest
from sanic_testing.testing import SanicTestClient from sanic_testing.testing import SanicTestClient
@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient):
return DelayableSanicSession(request_delay=self._request_delay) return DelayableSanicSession(request_delay=self._request_delay)
request_timeout_default_app = Sanic("test_request_timeout_default") @pytest.fixture
request_no_timeout_app = Sanic("test_request_no_timeout") def request_no_timeout_app():
request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6 app = Sanic("test_request_no_timeout")
request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6 app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@request_timeout_default_app.route("/1") @pytest.fixture
async def handler1(request): def request_timeout_default_app():
return text("OK") app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
@request_no_timeout_app.route("/1") def test_default_server_error_request_timeout(request_timeout_default_app):
async def handler2(request):
return text("OK")
@request_timeout_default_app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
def test_default_server_error_request_timeout():
client = DelayableSanicTestClient(request_timeout_default_app, 2) client = DelayableSanicTestClient(request_timeout_default_app, 2)
request, response = client.get("/1") _, response = client.get("/1")
assert response.status == 408 assert response.status == 408
assert "Request Timeout" in response.text assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout(): def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2) client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
request, response = client.get("/1") _, response = client.get("/1")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
def test_default_server_error_websocket_request_timeout(): def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = { headers = {
"Upgrade": "websocket", "Upgrade": "websocket",
@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout():
} }
client = DelayableSanicTestClient(request_timeout_default_app, 2) client = DelayableSanicTestClient(request_timeout_default_app, 2)
request, response = client.get("/ws1", headers=headers) _, response = client.get("/ws1", headers=headers)
assert response.status == 408 assert response.status == 408
assert "Request Timeout" in response.text assert "Request Timeout" in response.text

View File

@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"]) @pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url): async def test_websocket_route_asgi(app, url):
ev = asyncio.Event() @app.after_server_start
async def setup_ev(app, _):
app.ctx.ev = asyncio.Event()
@app.websocket(url) @app.websocket(url)
async def handler(request, ws): async def handler(request, ws):
ev.set() request.app.ctx.ev.set()
request, response = await app.asgi_client.websocket(url) @app.get("/ev")
assert ev.is_set() async def check(request):
return json({"set": request.app.ctx.ev.is_set()})
_, response = await app.asgi_client.websocket(url)
_, response = await app.asgi_client.get("/")
assert response.json["set"]
def test_websocket_route_with_subprotocols(app): @pytest.mark.parametrize(
results = [] "subprotocols,expected",
(
(["bar"], "bar"),
(["bar", "foo"], "bar"),
(["baz"], None),
(None, None),
),
)
def test_websocket_route_with_subprotocols(app, subprotocols, expected):
results = "unset"
@app.websocket("/ws", subprotocols=["foo", "bar"]) @app.websocket("/ws", subprotocols=["foo", "bar"])
async def handler(request, ws): async def handler(request, ws):
results.append(ws.subprotocol) nonlocal results
results = ws.subprotocol
assert ws.subprotocol is not None assert ws.subprotocol is not None
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"])
assert response.opened is True
assert results == ["bar"]
_, response = SanicTestClient(app).websocket( _, response = SanicTestClient(app).websocket(
"/ws", subprotocols=["bar", "foo"] "/ws", subprotocols=subprotocols
) )
assert response.opened is True assert response.opened is True
assert results == ["bar", "bar"] assert results == expected
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"])
assert response.opened is True
assert results == ["bar", "bar", None]
_, response = SanicTestClient(app).websocket("/ws")
assert response.opened is True
assert results == ["bar", "bar", None, None]
@pytest.mark.parametrize("strict_slashes", [True, False, None]) @pytest.mark.parametrize("strict_slashes", [True, False, None])

View File

@ -8,7 +8,7 @@ import pytest
from sanic_testing.testing import HOST, PORT from sanic_testing.testing import HOST, PORT
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage, SanicException
AVAILABLE_LISTENERS = [ AVAILABLE_LISTENERS = [
@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app):
async def init_db(app, loop): async def init_db(app, loop):
app.db = MySanicDb() app.db = MySanicDb()
await app.create_server(debug=True, return_asyncio_server=True, port=PORT) srv = await app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
await srv.startup()
await srv.before_start()
assert hasattr(app, "db") assert hasattr(app, "db")
assert isinstance(app.db, MySanicDb) assert isinstance(app.db, MySanicDb)
@ -157,14 +161,15 @@ def test_create_server_trigger_events(app):
serv_coro = app.create_server(return_asyncio_server=True, sock=sock) serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
serv_task = asyncio.ensure_future(serv_coro, loop=loop) serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task) server = loop.run_until_complete(serv_task)
server.after_start() loop.run_until_complete(server.startup())
loop.run_until_complete(server.after_start())
try: try:
loop.run_forever() loop.run_forever()
except KeyboardInterrupt as e: except KeyboardInterrupt:
loop.stop() loop.stop()
finally: finally:
# Run the on_stop function if provided # Run the on_stop function if provided
server.before_stop() loop.run_until_complete(server.before_stop())
# Wait for server to close # Wait for server to close
close_task = server.close() close_task = server.close()
@ -174,5 +179,19 @@ def test_create_server_trigger_events(app):
signal.stopped = True signal.stopped = True
for connection in server.connections: for connection in server.connections:
connection.close_if_idle() connection.close_if_idle()
server.after_stop() loop.run_until_complete(server.after_stop())
assert flag1 and flag2 and flag3 assert flag1 and flag2 and flag3
@pytest.mark.asyncio
async def test_missing_startup_raises_exception(app):
@app.listener("before_server_start")
async def init_db(app, loop):
...
srv = await app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
with pytest.raises(SanicException):
await srv.before_start()

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from sanic_testing.reusable import ReusableClient
from sanic_testing.testing import HOST, PORT from sanic_testing.testing import HOST, PORT
from sanic.compat import ctrlc_workaround_for_windows from sanic.compat import ctrlc_workaround_for_windows
@ -28,9 +29,13 @@ def set_loop(app, loop):
signal.signal = mock signal.signal = mock
else: else:
loop.add_signal_handler = mock loop.add_signal_handler = mock
print(">>>>>>>>>>>>>>>1", id(loop))
print(">>>>>>>>>>>>>>>1", loop.add_signal_handler)
def after(app, loop): def after(app, loop):
print(">>>>>>>>>>>>>>>2", id(loop))
print(">>>>>>>>>>>>>>>2", loop.add_signal_handler)
calledq.put(mock.called) calledq.put(mock.called)

View File

@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app):
app.signal_router.finalize() app.signal_router.finalize()
assert len(app.signal_router.routes) == 3
await app.dispatch("foo.bar.baz") await app.dispatch("foo.bar.baz")
assert counter == 2 assert counter == 2
@ -331,7 +332,8 @@ def test_event_on_bp_not_registered():
"event,expected", "event,expected",
( (
("foo.bar.baz", True), ("foo.bar.baz", True),
("server.init.before", False), ("server.init.before", True),
("server.init.somethingelse", False),
("http.request.start", False), ("http.request.start", False),
("sanic.notice.anything", True), ("sanic.notice.anything", True),
), ),

21
tests/test_touchup.py Normal file
View File

@ -0,0 +1,21 @@
import logging
from sanic.signals import RESERVED_NAMESPACES
from sanic.touchup import TouchUp
def test_touchup_methods(app):
assert len(TouchUp._registry) == 9
async def test_ode_removes_dispatch_events(app, caplog):
with caplog.at_level(logging.DEBUG, logger="sanic.root"):
await app._startup()
logs = caplog.record_tuples
for signal in RESERVED_NAMESPACES["http"]:
assert (
"sanic.root",
logging.DEBUG,
f"Disabling event: {signal}",
) in logs

View File

@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app):
) )
def test_websocket_bp_route_name(app): @pytest.mark.parametrize(
"name,expected",
(
("test_route", "/bp/route"),
("test_route2", "/bp/route2"),
("foobar_3", "/bp/route3"),
),
)
def test_websocket_bp_route_name(app, name, expected):
"""Tests that blueprint websocket route is named.""" """Tests that blueprint websocket route is named."""
event = asyncio.Event() event = asyncio.Event()
bp = Blueprint("test_bp", url_prefix="/bp") bp = Blueprint("test_bp", url_prefix="/bp")
@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app):
uri = app.url_for("test_bp.main") uri = app.url_for("test_bp.main")
assert uri == "/bp/main" assert uri == "/bp/main"
uri = app.url_for("test_bp.test_route") uri = app.url_for(f"test_bp.{name}")
assert uri == "/bp/route" assert uri == expected
request, response = SanicTestClient(app).websocket(uri) request, response = SanicTestClient(app).websocket(uri)
assert response.opened is True assert response.opened is True
assert event.is_set() assert event.is_set()
event.clear()
uri = app.url_for("test_bp.test_route2")
assert uri == "/bp/route2"
request, response = SanicTestClient(app).websocket(uri)
assert response.opened is True
assert event.is_set()
uri = app.url_for("test_bp.foobar_3")
assert uri == "/bp/route3"
# TODO: add test with a route with multiple hosts # TODO: add test with a route with multiple hosts
# TODO: add test with a route with _host in url_for # TODO: add test with a route with _host in url_for

View File

@ -10,7 +10,7 @@ extras = test
commands = commands =
pytest {posargs:tests --cov sanic} pytest {posargs:tests --cov sanic}
- coverage combine --append - coverage combine --append
coverage report -m coverage report -m -i
coverage html -i coverage html -i
[testenv:lint] [testenv:lint]