Compare commits


41 Commits

Author  SHA1  Message  Date
Adam Hopkins  38b4ccf2bc  Cleanup implementation  2022-09-19 21:34:50 +03:00
Adam Hopkins  8b970dd490  Merge branch 'main' of github.com:sanic-org/sanic into middleware-revamp  2022-09-19 16:04:38 +03:00
Adam Hopkins  389363ab71  Better request cancel handling (#2513)  2022-09-19 16:04:09 +03:00
Adam Hopkins  c9be17e8da  Merge conflicts  2022-09-18 23:48:06 +03:00
Adam Hopkins  7f894c45b3  Add deprecation warning filter (#2546)  2022-09-18 18:54:35 +03:00
Adam Hopkins  4726cf1910  Sanic Server WorkerManager refactor (#2499)  2022-09-18 17:17:23 +03:00
    Co-authored-by: Néstor Pérez <25409753+prryplatypus@users.noreply.github.com>
Adam Hopkins  19f642b364  Add to tests  2022-09-15 18:46:09 +03:00
Adam Hopkins  c4c39cb082  Merge branch 'main' of github.com:sanic-org/sanic into middleware-revamp  2022-09-15 18:33:22 +03:00
Adam Hopkins  d352a4155e  Add signals before and after handler execution (#2540)  2022-09-15 15:49:21 +03:00
Adam Hopkins  e5010286b4  Raise warning and deprecation notice on violations (#2537)  2022-09-15 15:24:46 +03:00
Adam Hopkins  358498db96  Do not apply double slash to Blueprint and static dirs (#2515)  2022-09-15 14:43:20 +03:00
monosans  e4999401ab  Improve and fix some type annotations (#2536)  2022-09-13 08:53:48 +03:00
Néstor Pérez  c8df0aa2cb  Fix easter egg through CLI (#2542)  2022-09-12 01:44:21 +03:00
Adam Hopkins  5fb207176b  Update bug_report.md  2022-08-29 12:47:01 +03:00
Adam Hopkins  a12b560478  Update feature_request.md  2022-08-29 12:46:39 +03:00
Adam Hopkins  c7bac72137  WIP  2022-08-20 22:24:43 +03:00
Zhiwei  753ee992a6  Validate File When Requested (#2526)  2022-08-18 12:05:05 +03:00
    Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Hunt Zhan  09089b1bd3  Resolve issue 2529 (#2530)  2022-08-18 08:58:07 +03:00
Adam Hopkins  beb5c62767  Add global middleware ordering  2022-08-17 21:57:07 +03:00
Adam Hopkins  09b59d34fe  Fix typing error  2022-08-17 15:26:59 +03:00
Adam Hopkins  78bc475bb1  Add test case  2022-08-17 15:23:30 +03:00
Adam Hopkins  b59131504b  Merge branch 'main' into middleware-revamp  2022-08-17 14:17:34 +03:00
Adam Hopkins  7ddbe5e844  Update config.yml  2022-08-17 10:44:02 +03:00
Adam Hopkins  ab5a7038af  Update SECURITY.md  2022-08-17 10:42:22 +03:00
Adam Hopkins  4f3c780dc3  Update feature_request.md  2022-08-17 10:38:15 +03:00
Adam Hopkins  71f7765a4c  Update bug_report.md  2022-08-17 10:37:35 +03:00
Adam Hopkins  0392d1dcfc  Always show server location in ASGI (#2522)  2022-08-11 10:00:35 +03:00
    Co-authored-by: Adam Hopkins <adam@amhopkins.com>
    Co-authored-by: Zhiwei Liang <zhi.wei.liang@outlook.com>
    Co-authored-by: Néstor Pérez <25409753+prryplatypus@users.noreply.github.com>
Adam Hopkins  7827b1b41d  Add Request properties for HTTP method info (#2516)  2022-08-10 21:12:09 +03:00
Adam Hopkins  8e9342e188  Warn on duplicate route names (#2525)  2022-08-10 20:36:47 +03:00
Adam Hopkins  782e0881e5  Slots to Middleware  2022-08-07 22:38:25 +03:00
Adam Hopkins  c72cbe4326  Begin middleware revamp  2022-08-07 22:31:26 +03:00
Adam Hopkins  2f6f2bfa76  Rename code of conduct  2022-08-02 13:56:53 +03:00
Ryu Juheon  dee09d7fff  style: add some type hints (#2517)  2022-08-02 08:47:59 +03:00
Adam Hopkins  9cf38a0a83  MERGEBACK (#2495) (#2512)  2022-07-31 15:50:46 +03:00
    Co-authored-by: Adam Hopkins <adam@amhopkins.com>
    Co-authored-by: Zhiwei Liang <zhi.wei.liang@outlook.com>
    Co-authored-by: Néstor Pérez <25409753+prryplatypus@users.noreply.github.com>
Adam Hopkins  3def3d3569  Use path.parts instead of match (#2508)  2022-07-31 12:54:42 +03:00
Adam Hopkins  e100a14fd4  Use pathlib for path resolution (#2506)  2022-07-31 08:49:02 +03:00
Adam Hopkins  2fa28f1711  Fix dotted test  2022-07-28 10:07:30 +03:00
Néstor Pérez  9d415e4ec6  Prevent directory traversion with static files (#2495)  2022-07-28 09:45:45 +03:00
    Co-authored-by: Adam Hopkins <adam@amhopkins.com>
    Co-authored-by: Zhiwei Liang <zhi.wei.liang@outlook.com>
Tim Gates  312ab298fd  docs: Fix a few typos (#2502)  2022-07-24 22:47:39 +03:00
    Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Zhiwei  2fc21ad576  Replace Unsupported Python Version Number from the Contributing Doc (#2505)  2022-07-24 22:33:05 +03:00
    Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Zhiwei  8f6c87c3d6  Fix Test Cases: test_http for Py3.9+, test_json_response_json for ujson 5.4.0+, and test_zero_downtime; Test Case Type Annotations (#2504)  2022-07-24 22:07:54 +03:00
108 changed files with 5389 additions and 1487 deletions

View File

@@ -4,8 +4,8 @@ source = sanic
omit =
site-packages
sanic/__main__.py
sanic/server/legacy.py
sanic/compat.py
sanic/reloader_helpers.py
sanic/simple.py
sanic/utils.py
sanic/cli
@@ -21,12 +21,4 @@ exclude_lines =
NOQA
pragma: no cover
TYPE_CHECKING
omit =
site-packages
sanic/__main__.py
sanic/compat.py
sanic/reloader_helpers.py
sanic/simple.py
sanic/utils.py
sanic/cli
skip_empty = True

View File

@@ -1,25 +1,27 @@
---
name: Bug report
about: Create a report to help us improve
labels: ["bug"]
---
**Describe the bug**
A clear and concise description of what the bug is, make sure to paste any exceptions and tracebacks.
<!-- A clear and concise description of what the bug is, make sure to paste any exceptions and tracebacks. -->
**Code snippet**
Relevant source code, make sure to remove what is not necessary.
<!-- Relevant source code, make sure to remove what is not necessary. -->
**Expected behavior**
A clear and concise description of what you expected to happen.
<!-- A clear and concise description of what you expected to happen. -->
**Environment (please complete the following information):**
- OS: [e.g. iOS]
- Version [e.g. 0.8.3]
<!-- Please provide the information below. Instead, you can copy and paste the message that Sanic shows on startup. If you do, please remember to format it with ``` -->
- OS:
- Sanic Version:
**Additional context**
Add any other context about the problem here.
<!-- Add any other context about the problem here. -->

View File

@@ -3,3 +3,6 @@ contact_links:
- name: Questions and Help
url: https://community.sanicframework.org/c/questions-and-help
about: Do you need help with Sanic? Ask your questions here.
- name: Discussion and Support
url: https://discord.gg/FARQzAEMAA
about: For live discussion and support, checkout the Sanic Discord server.

View File

@@ -1,16 +1,17 @@
---
name: Feature request
about: Suggest an idea for Sanic
labels: ["feature request"]
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
<!-- A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] -->
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
<!-- A clear and concise description of what you want to happen. -->
**Additional context**
Add any other context or sample code about the feature request here.
<!-- Add any other context or sample code about the feature request here. -->

View File

@@ -55,7 +55,7 @@ further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at sanic-maintainers@googlegroups.com. All
reported by contacting the project team at adam@sanicframework.org. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.

View File

@@ -71,9 +71,9 @@ To execute only unittests, run ``tox`` with environment like so:
.. code-block:: bash
tox -e py36 -v -- tests/test_config.py
# or
tox -e py37 -v -- tests/test_config.py
# or
tox -e py310 -v -- tests/test_config.py
Run lint checks
---------------

View File

@@ -4,31 +4,40 @@
Sanic releases long term support release once a year in December. LTS releases receive bug and security updates for **24 months**. Interim releases throughout the year occur every three months, and are supported until the subsequent interim release.
| Version | LTS | Supported |
| ------- | ------------- | ------------------ |
| 20.12 | until 2022-12 | :heavy_check_mark: |
| 20.9 | | :x: |
| 20.6 | | :x: |
| 20.3 | | :x: |
| 19.12 | until 2021-12 | :white_check_mark: |
| 19.9 | | :x: |
| 19.6 | | :x: |
| 19.3 | | :x: |
| 18.12 | | :x: |
| 0.8.3 | | :x: |
| 0.7.0 | | :x: |
| 0.6.0 | | :x: |
| 0.5.4 | | :x: |
| 0.4.1 | | :x: |
| 0.3.1 | | :x: |
| 0.2.0 | | :x: |
| 0.1.9 | | :x: |
:white_check_mark: = security/bug fixes
:heavy_check_mark: = full support
| Version | LTS | Supported |
| ------- | ------------- | ----------------------- |
| 22.6 | | :white_check_mark: |
| 22.3 | | :x: |
| 21.12 | until 2023-12 | :white_check_mark: |
| 21.9 | | :x: |
| 21.6 | | :x: |
| 21.3 | | :x: |
| 20.12 | until 2022-12 | :ballot_box_with_check: |
| 20.9 | | :x: |
| 20.6 | | :x: |
| 20.3 | | :x: |
| 19.12 | | :x: |
| 19.9 | | :x: |
| 19.6 | | :x: |
| 19.3 | | :x: |
| 18.12 | | :x: |
| 0.8.3 | | :x: |
| 0.7.0 | | :x: |
| 0.6.0 | | :x: |
| 0.5.4 | | :x: |
| 0.4.1 | | :x: |
| 0.3.1 | | :x: |
| 0.2.0 | | :x: |
| 0.1.9 | | :x: |
:ballot_box_with_check: = security/bug fixes
:white_check_mark: = full support
## Reporting a Vulnerability
If you discover a security vulnerability, we ask that you **do not** create an issue on GitHub. Instead, please [send a message to the core-devs](https://community.sanicframework.org/g/core-devs) on the community forums. Once logged in, you can send a message to the core-devs by clicking the message button.
Alternatively, you can send a private message to Adam Hopkins on Discord. Find him on the [Sanic discord server](https://discord.gg/FARQzAEMAA).
This will help to not publicize the issue until the team can address it and resolve it.

View File

@@ -4,6 +4,7 @@ coverage:
default:
target: auto
threshold: 0.75
informational: true
project:
default:
target: auto
@@ -14,7 +15,6 @@ codecov:
ignore:
- "sanic/__main__.py"
- "sanic/compat.py"
- "sanic/reloader_helpers.py"
- "sanic/simple.py"
- "sanic/utils.py"
- "sanic/cli"

View File

@@ -1,6 +1,3 @@
import os
import socket
from sanic import Sanic, response
@@ -13,13 +10,4 @@ async def test(request):
if __name__ == "__main__":
server_address = "./uds_socket"
# Make sure the socket does not already exist
try:
os.unlink(server_address)
except OSError:
if os.path.exists(server_address):
raise
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(server_address)
app.run(sock=sock)
app.run(unix="./uds_socket")
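
For reference, a complete version of the simplified example might look like the sketch below; the app name and route are illustrative, and only app.run(unix=...) comes directly from the diff above.

    from sanic import Sanic, response

    app = Sanic("UnixSocketExample")  # name is illustrative

    @app.get("/")
    async def test(request):
        return response.text("OK")

    if __name__ == "__main__":
        # The manual AF_UNIX socket setup removed above is no longer needed;
        # passing unix= lets Sanic create and bind the socket itself.
        app.run(unix="./uds_socket")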

View File

@@ -3,7 +3,15 @@ from sanic.app import Sanic
from sanic.blueprints import Blueprint
from sanic.constants import HTTPMethod
from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text
from sanic.response import (
HTTPResponse,
empty,
file,
html,
json,
redirect,
text,
)
from sanic.server.websockets.impl import WebsocketImplProtocol as Websocket
@@ -15,7 +23,10 @@ __all__ = (
"HTTPResponse",
"Request",
"Websocket",
"empty",
"file",
"html",
"json",
"redirect",
"text",
)
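
With empty, file, and redirect now re-exported from the package root, handlers can import them directly from sanic. A minimal sketch (route paths are illustrative):

    from sanic import Sanic, empty, redirect

    app = Sanic("Example")

    @app.get("/ping")
    async def ping(request):
        return empty()  # 204 No Content, newly importable from the package root

    @app.get("/old")
    async def old(request):
        return redirect("/new")  # also newly importable from the package root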

View File

@@ -6,10 +6,10 @@ if OS_IS_WINDOWS:
enable_windows_color_support()
def main():
def main(args=None):
cli = SanicCLI()
cli.attach()
cli.run()
cli.run(args)
if __name__ == "__main__":
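
Because main() now accepts an explicit argument list (forwarded to cli.run(args)), the CLI can be driven programmatically, for example from tests. A minimal sketch, with a hypothetical mymodule:app target:

    from sanic.__main__ import main

    # Equivalent to `sanic mymodule:app --port 8000` in a shell; this starts the
    # server just as the command line would. Passing None (the default) still
    # falls back to parsing sys.argv.
    main(["mymodule:app", "--port", "8000"])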

View File

@@ -1 +1 @@
__version__ = "22.6.0"
__version__ = "22.9.1"

View File

@@ -19,8 +19,8 @@ from collections import defaultdict, deque
from contextlib import suppress
from functools import partial
from inspect import isawaitable
from os import environ
from socket import socket
from traceback import format_exc
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
@@ -41,7 +41,6 @@ from typing import (
Union,
)
from urllib.parse import urlencode, urlunparse
from warnings import filterwarnings
from sanic_routing.exceptions import FinalizationError, NotFound
from sanic_routing.route import Route
@@ -54,12 +53,7 @@ from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support
from sanic.config import SANIC_PREFIX, Config
from sanic.exceptions import (
BadRequest,
SanicException,
ServerError,
URLBuildError,
)
from sanic.exceptions import BadRequest, SanicException, URLBuildError
from sanic.handlers import ErrorHandler
from sanic.helpers import _default
from sanic.http import Stage
@@ -70,7 +64,7 @@ from sanic.log import (
logger,
)
from sanic.mixins.listeners import ListenerEvent
from sanic.mixins.runner import RunnerMixin
from sanic.mixins.startup import StartupMixin
from sanic.models.futures import (
FutureException,
FutureListener,
@@ -83,11 +77,14 @@ from sanic.models.futures import (
from sanic.models.handler_types import ListenerType, MiddlewareType
from sanic.models.handler_types import Sanic as SanicVar
from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream
from sanic.response import BaseHTTPResponse
from sanic.router import Router
from sanic.server.websockets.impl import ConnectionClosed
from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.types.shared_ctx import SharedContext
from sanic.worker.inspector import Inspector
from sanic.worker.manager import WorkerManager
if TYPE_CHECKING:
@@ -101,10 +98,8 @@ if TYPE_CHECKING:
if OS_IS_WINDOWS: # no cov
enable_windows_color_support()
filterwarnings("once", category=DeprecationWarning)
class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
"""
The main application instance
"""
@@ -128,6 +123,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"_future_routes",
"_future_signals",
"_future_statics",
"_inspector",
"_manager",
"_state",
"_task_registry",
"_test_client",
@@ -139,12 +136,14 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"error_handler",
"go_fast",
"listeners",
"multiplexer",
"named_request_middleware",
"named_response_middleware",
"request_class",
"request_middleware",
"response_middleware",
"router",
"shared_ctx",
"signal_router",
"sock",
"strict_slashes",
@@ -171,9 +170,9 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
configure_logging: bool = True,
dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
inspector: bool = False,
) -> None:
super().__init__(name=name)
# logging
if configure_logging:
dict_config = log_config or LOGGING_CONFIG_DEFAULTS
@@ -187,12 +186,16 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# First setup config
self.config: Config = config or Config(env_prefix=env_prefix)
if inspector:
self.config.INSPECTOR = inspector
# Then we can do the rest
self._asgi_client: Any = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
self._future_registry: FutureRegistry = FutureRegistry()
self._inspector: Optional[Inspector] = None
self._manager: Optional[WorkerManager] = None
self._state: ApplicationState = ApplicationState(app=self)
self._task_registry: Dict[str, Task] = {}
self._test_client: Any = None
@@ -210,6 +213,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
self.request_middleware: Deque[MiddlewareType] = deque()
self.response_middleware: Deque[MiddlewareType] = deque()
self.router: Router = router or Router()
self.shared_ctx: SharedContext = SharedContext()
self.signal_router: SignalRouter = signal_router or SignalRouter()
self.sock: Optional[socket] = None
self.strict_slashes: bool = strict_slashes
@@ -243,7 +247,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
)
try:
return get_running_loop()
except RuntimeError:
except RuntimeError: # no cov
if sys.version_info > (3, 10):
return asyncio.get_event_loop_policy().get_event_loop()
else:
@@ -458,9 +462,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
def blueprint(
self,
blueprint: Union[
Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup
],
blueprint: Union[Blueprint, Iterable[Blueprint], BlueprintGroup],
**options: Any,
):
"""Register a blueprint on the application.
@@ -469,7 +471,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
:param options: option dictionary with blueprint defaults
:return: Nothing
"""
if isinstance(blueprint, (list, tuple, BlueprintGroup)):
if isinstance(blueprint, (Iterable, BlueprintGroup)):
for item in blueprint:
params = {**options}
if isinstance(blueprint, BlueprintGroup):
@@ -701,255 +703,15 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- #
async def handle_exception(
self, request: Request, exception: BaseException
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
:param request: The current request object
:param exception: The exception that was raised
:raises ServerError: response 500
"""
await self.dispatch(
"http.lifecycle.exception",
inline=True,
context={"request": request, "exception": exception},
)
if (
request.stream is not None
and request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
handler = self.error_handler._lookup(
exception, request.name if request else None
)
if handler:
logger.warning(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"The response from your custom exception handler "
f"{handler.__name__} will not be sent to the client."
"Exception handlers should only be used to generate the "
"exception responses. If you would like to perform any "
"other action on a raised exception, consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
)
return
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
# No middleware results
if not response:
try:
response = self.error_handler.response(request, exception)
if isawaitable(response):
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(request, e)
elif self.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
if response is not None:
try:
request.reset_response()
response = await request.respond(response)
except BaseException:
# Skip response middleware
if request.stream:
request.stream.respond(response)
await response.send(end_stream=True)
raise
else:
if request.stream:
response = request.stream.response
# Marked for cleanup and DRY with handle_request/handle_exception
# when ResponseStream is no longer supporder
if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": resp,
},
)
await response.eof()
else:
raise ServerError(
f"Invalid response type {response!r} (need HTTPResponse)"
)
raise NotImplementedError
async def handle_request(self, request: Request): # no cov
"""Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so
exception handling must be done here
:param request: HTTP Request object
:return: Nothing
"""
await self.dispatch(
"http.lifecycle.handle",
inline=True,
context={"request": request},
)
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
try:
await self.dispatch(
"http.routing.before",
inline=True,
context={"request": request},
)
# Fetch handler from router
route, handler, kwargs = self.router.get(
request.path,
request.method,
request.headers.getone("host", None),
)
request._match_info = {**kwargs}
request.route = route
await self.dispatch(
"http.routing.after",
inline=True,
context={
"request": request,
"route": route,
"kwargs": kwargs,
"handler": handler,
},
)
if (
request.stream
and request.stream.request_body
and not route.ctx.ignore_body
):
if hasattr(handler, "is_stream"):
# Streaming handler: lift the size limit
request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await request.receive_body()
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=route.name
)
# No middleware results
if not response:
# -------------------------------------------- #
# Execute Handler
# -------------------------------------------- #
if handler is None:
raise ServerError(
(
"'None' was returned while requesting a "
"handler from the router"
)
)
# Run response handler
response = handler(request, **request.match_info)
if isawaitable(response):
response = await response
if request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response)
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore
# Marked for cleanup and DRY with handle_request/handle_exception
# when ResponseStream is no longer supporder
if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": resp,
},
)
await response.eof()
else:
if not hasattr(handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
except CancelledError:
raise
except Exception as e:
# Response Generation Failed
await self.handle_exception(request, e)
raise NotImplementedError
async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
@@ -999,7 +761,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
def asgi_client(self): # noqa
"""
A testing client that uses ASGI to reach into the application to
execute hanlers.
execute handlers.
:return: testing client
:rtype: :class:`SanicASGITestClient`
@@ -1018,86 +780,72 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- #
async def _run_request_middleware(
self, request, request_name=None
self, request, middleware_collection
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
request._request_middleware_started = True
# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
for middleware in middleware_collection:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request)
if isawaitable(response):
response = await response
response = middleware(request)
if isawaitable(response):
response = await response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
if response:
return response
return None
async def _run_response_middleware(
self, request, response, request_name=None
self, request, response, middleware_collection
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
for middleware in middleware_collection:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
return response
def _build_endpoint_name(self, *parts):
@@ -1184,7 +932,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
*,
name: Optional[str] = None,
register: bool = True,
) -> Optional[Task]:
) -> Optional[Task[Any]]:
"""
Schedule a task to run later, after the loop has started.
Different from asyncio.ensure_future in that it does not
@@ -1194,7 +942,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
`See user guide re: background tasks
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`__
:param task: future, couroutine or awaitable
:param task: future, coroutine or awaitable
"""
try:
loop = self.loop # Will raise SanicError if loop is not started
@@ -1315,7 +1063,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
self.config.update_config(config)
@property
def asgi(self):
def asgi(self) -> bool:
return self.state.asgi
@asgi.setter
@@ -1345,6 +1093,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
@auto_reload.setter
def auto_reload(self, value: bool):
self.config.AUTO_RELOAD = value
self.state.auto_reload = value
@property
def state(self) -> ApplicationState: # type: ignore
@@ -1462,6 +1211,18 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
cls._app_registry[name] = app
@classmethod
def unregister_app(cls, app: "Sanic") -> None:
"""
Unregister a Sanic instance
"""
if not isinstance(app, cls):
raise SanicException("Registered app must be an instance of Sanic")
name = app.name
if name in cls._app_registry:
del cls._app_registry[name]
@classmethod
def get_app(
cls, name: Optional[str] = None, *, force_create: bool = False
@@ -1481,6 +1242,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
try:
return cls._app_registry[name]
except KeyError:
if name == "__main__":
return cls.get_app("__mp_main__", force_create=force_create)
if force_create:
return cls(name)
raise SanicException(f'Sanic app name "{name}" not found.')
@@ -1495,6 +1258,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
except FinalizationError as e:
if not Sanic.test_mode:
raise e
self.finalize_middleware()
def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin
@@ -1521,6 +1285,18 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
self.signalize(self.config.TOUCHUP)
self.finalize()
route_names = [route.name for route in self.router.routes]
duplicates = {
name for name in route_names if route_names.count(name) > 1
}
if duplicates:
names = ", ".join(duplicates)
deprecation(
f"Duplicate route names detected: {names}. In the future, "
"Sanic will enforce uniqueness in route naming.",
23.3,
)
# TODO: Replace in v22.6 to check against apps in app registry
if (
self.__class__._uvloop_setting is not None
@@ -1542,6 +1318,9 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
self.state.is_started = True
if hasattr(self, "multiplexer"):
self.multiplexer.ack()
async def _server_event(
self,
concern: str,
@@ -1570,3 +1349,43 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"loop": loop,
},
)
# -------------------------------------------------------------------- #
# Process Management
# -------------------------------------------------------------------- #
def refresh(
self,
passthru: Optional[Dict[str, Any]] = None,
):
registered = self.__class__.get_app(self.name)
if self is not registered:
if not registered.state.server_info:
registered.state.server_info = self.state.server_info
self = registered
if passthru:
for attr, info in passthru.items():
if isinstance(info, dict):
for key, value in info.items():
setattr(getattr(self, attr), key, value)
else:
setattr(self, attr, info)
if hasattr(self, "multiplexer"):
self.shared_ctx.lock()
return self
@property
def inspector(self):
if environ.get("SANIC_WORKER_PROCESS") or not self._inspector:
raise SanicException(
"Can only access the inspector from the main process"
)
return self._inspector
@property
def manager(self):
if environ.get("SANIC_WORKER_PROCESS") or not self._manager:
raise SanicException(
"Can only access the manager from the main process"
)
return self._manager
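
One small, self-contained piece of the sanic/app.py changes above is the new unregister_app() classmethod. A minimal sketch of how it pairs with the existing registry API, assuming the default behavior of registering apps on construction:

    from sanic import Sanic

    app = Sanic("MyApp")                # registered in Sanic._app_registry
    assert Sanic.get_app("MyApp") is app

    Sanic.unregister_app(app)           # new: removes it from the registry
    # A later Sanic.get_app("MyApp") would now raise SanicException unless
    # force_create=True is passed.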

View File

@@ -1,15 +1,24 @@
from enum import Enum, IntEnum, auto
class StrEnum(str, Enum):
class StrEnum(str, Enum): # no cov
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
def __eq__(self, value: object) -> bool:
value = str(value).upper()
return super().__eq__(value)
def __hash__(self) -> int:
return hash(self.value)
def __str__(self) -> str:
return self.value
class Server(StrEnum):
SANIC = auto()
ASGI = auto()
GUNICORN = auto()
class Mode(StrEnum):

View File

@@ -22,7 +22,7 @@ def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
with suppress(ModuleNotFoundError):
sanic_ext = import_module("sanic_ext")
if not sanic_ext:
if not sanic_ext: # no cov
if fail:
raise RuntimeError(
"Sanic Extensions is not installed. You can add it to your "

View File

@@ -80,20 +80,23 @@ class MOTDTTY(MOTD):
)
self.display_length = self.key_width + self.value_width + 2
def display(self):
version = f"Sanic v{__version__}".center(self.centering_length)
def display(self, version=True, action="Goin' Fast", out=None):
if not out:
out = logger.info
header = "Sanic"
if version:
header += f" v{__version__}"
header = header.center(self.centering_length)
running = (
f"Goin' Fast @ {self.serve_location}"
if self.serve_location
else ""
f"{action} @ {self.serve_location}" if self.serve_location else ""
).center(self.centering_length)
length = len(version) + 2 - self.logo_line_length
length = len(header) + 2 - self.logo_line_length
first_filler = "" * (self.logo_line_length - 1)
second_filler = "" * length
display_filler = "" * (self.display_length + 2)
lines = [
f"\n{first_filler}{second_filler}",
f"{version}",
f"{header}",
f"{running}",
f"{first_filler}{second_filler}",
]
@@ -107,7 +110,7 @@ class MOTDTTY(MOTD):
self._render_fill(lines)
lines.append(f"{first_filler}{second_filler}\n")
logger.info(indent("\n".join(lines), " "))
out(indent("\n".join(lines), " "))
def _render_data(self, lines, data, start):
offset = 0

View File

@@ -7,6 +7,7 @@ from urllib.parse import quote
from sanic.compat import Header
from sanic.exceptions import ServerError
from sanic.handlers import RequestManager
from sanic.helpers import _default
from sanic.http import Stage
from sanic.log import logger
@@ -230,8 +231,9 @@ class ASGIApp:
"""
Handle the incoming request.
"""
manager = RequestManager.create(self.request)
try:
self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request)
await manager.handle()
except Exception as e:
await self.sanic_app.handle_exception(self.request, e)
await manager.error(e)

View File

@@ -308,7 +308,7 @@ class Blueprint(BaseSanic):
# prefixed properly in the router
future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
uri = self._setup_uri(future.uri, url_prefix)
version_prefix = self.version_prefix
for prefix in (
@@ -333,7 +333,7 @@ class Blueprint(BaseSanic):
apply_route = FutureRoute(
future.handler,
uri[1:] if uri.startswith("//") else uri,
uri,
future.methods,
host,
strict_slashes,
@@ -363,7 +363,7 @@ class Blueprint(BaseSanic):
# Static Files
for future in self._future_statics:
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
uri = self._setup_uri(future.uri, url_prefix)
apply_route = FutureStatic(uri, *future[1:])
if (self, apply_route) in app._future_registry:
@@ -456,6 +456,18 @@ class Blueprint(BaseSanic):
break
return value
@staticmethod
def _setup_uri(base: str, prefix: Optional[str]):
uri = base
if prefix:
uri = prefix
if base.startswith("/") and prefix.endswith("/"):
uri += base[1:]
else:
uri += base
return uri[1:] if uri.startswith("//") else uri
@staticmethod
def register_futures(
apps: Set[Sanic], bp: Blueprint, futures: Sequence[Tuple[Any, ...]]
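
The new _setup_uri() helper above is what prevents the double-slash URIs addressed in commit 358498db96 (#2515). It is a private static method, but its joining behavior can be illustrated directly; the values below are chosen for illustration:

    from sanic.blueprints import Blueprint

    # A trailing slash on the prefix plus a leading slash on the route no
    # longer produces "//": the overlap is collapsed by _setup_uri().
    assert Blueprint._setup_uri("/users", "/api/") == "/api/users"
    assert Blueprint._setup_uri("/users", "/api") == "/api/users"
    assert Blueprint._setup_uri("/users", None) == "/users"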

View File

@@ -1,10 +1,10 @@
import logging
import os
import shutil
import sys
from argparse import ArgumentParser, RawTextHelpFormatter
from importlib import import_module
from pathlib import Path
from functools import partial
from textwrap import indent
from typing import Any, List, Union
@@ -12,7 +12,8 @@ from sanic.app import Sanic
from sanic.application.logo import get_logo
from sanic.cli.arguments import Group
from sanic.log import error_logger
from sanic.simple import create_simple_server
from sanic.worker.inspector import inspect
from sanic.worker.loader import AppLoader
class SanicArgumentParser(ArgumentParser):
@@ -66,13 +67,17 @@ Or, a path to a directory to run as a simple HTTP server:
instance.attach()
self.groups.append(instance)
def run(self):
# This is to provide backwards compat -v to display version
legacy_version = len(sys.argv) == 2 and sys.argv[-1] == "-v"
parse_args = ["--version"] if legacy_version else None
def run(self, parse_args=None):
legacy_version = False
if not parse_args:
parsed, unknown = self.parser.parse_known_args()
# This is to provide backwards compat -v to display version
legacy_version = len(sys.argv) == 2 and sys.argv[-1] == "-v"
parse_args = ["--version"] if legacy_version else None
elif parse_args == ["-v"]:
parse_args = ["--version"]
if not legacy_version:
parsed, unknown = self.parser.parse_known_args(args=parse_args)
if unknown and parsed.factory:
for arg in unknown:
if arg.startswith("--"):
@@ -80,20 +85,47 @@ Or, a path to a directory to run as a simple HTTP server:
self.args = self.parser.parse_args(args=parse_args)
self._precheck()
app_loader = AppLoader(
self.args.module,
self.args.factory,
self.args.simple,
self.args,
)
try:
app = self._get_app()
app = self._get_app(app_loader)
kwargs = self._build_run_kwargs()
except ValueError:
error_logger.exception("Failed to run app")
except ValueError as e:
error_logger.exception(f"Failed to run app: {e}")
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)
if self.args.inspect or self.args.inspect_raw or self.args.trigger:
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true"
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)
Sanic.serve()
if self.args.inspect or self.args.inspect_raw or self.args.trigger:
action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
inspect(
app.config.INSPECTOR_HOST,
app.config.INSPECTOR_PORT,
action,
)
del os.environ["SANIC_IGNORE_PRODUCTION_WARNING"]
return
if self.args.single:
serve = Sanic.serve_single
elif self.args.legacy:
serve = Sanic.serve_legacy
else:
serve = partial(Sanic.serve, app_loader=app_loader)
serve(app)
def _precheck(self):
# # Custom TLS mismatch handling for better diagnostics
# Custom TLS mismatch handling for better diagnostics
if self.main_process and (
# one of cert/key missing
bool(self.args.cert) != bool(self.args.key)
@@ -113,58 +145,14 @@ Or, a path to a directory to run as a simple HTTP server:
)
error_logger.error(message)
sys.exit(1)
if self.args.inspect or self.args.inspect_raw:
logging.disable(logging.CRITICAL)
def _get_app(self):
def _get_app(self, app_loader: AppLoader):
try:
module_path = os.path.abspath(os.getcwd())
if module_path not in sys.path:
sys.path.append(module_path)
if self.args.simple:
path = Path(self.args.module)
app = create_simple_server(path)
else:
delimiter = ":" if ":" in self.args.module else "."
module_name, app_name = self.args.module.rsplit(delimiter, 1)
if module_name == "" and os.path.isdir(self.args.module):
raise ValueError(
"App not found.\n"
" Please use --simple if you are passing a "
"directory to sanic.\n"
f" eg. sanic {self.args.module} --simple"
)
if app_name.endswith("()"):
self.args.factory = True
app_name = app_name[:-2]
module = import_module(module_name)
app = getattr(module, app_name, None)
if self.args.factory:
try:
app = app(self.args)
except TypeError:
app = app()
app_type_name = type(app).__name__
if not isinstance(app, Sanic):
if callable(app):
solution = f"sanic {self.args.module} --factory"
raise ValueError(
"Module is not a Sanic app, it is a "
f"{app_type_name}\n"
" If this callable returns a "
f"Sanic instance try: \n{solution}"
)
raise ValueError(
f"Module is not a Sanic app, it is a {app_type_name}\n"
f" Perhaps you meant {self.args.module}:app?"
)
app = app_loader.load()
except ImportError as e:
if module_name.startswith(e.name):
if app_loader.module_name.startswith(e.name): # type: ignore
error_logger.error(
f"No module named {e.name} found.\n"
" Example File: project/sanic_server.py -> app\n"
@@ -190,8 +178,10 @@ Or, a path to a directory to run as a simple HTTP server:
elif len(ssl) == 1 and ssl[0] is not None:
# Use only one cert, no TLSSelector.
ssl = ssl[0]
kwargs = {
"access_log": self.args.access_log,
"coffee": self.args.coffee,
"debug": self.args.debug,
"fast": self.args.fast,
"host": self.args.host,
@@ -203,6 +193,8 @@ Or, a path to a directory to run as a simple HTTP server:
"verbosity": self.args.verbosity or 0,
"workers": self.args.workers,
"auto_tls": self.args.auto_tls,
"single_process": self.args.single,
"legacy": self.args.legacy,
}
for maybe_arg in ("auto_reload", "dev"):

View File

@@ -30,7 +30,7 @@ class Group:
instance = cls(parser, cls.name)
return instance
def add_bool_arguments(self, *args, **kwargs):
def add_bool_arguments(self, *args, nullable=False, **kwargs):
group = self.container.add_mutually_exclusive_group()
kwargs["help"] = kwargs["help"].capitalize()
group.add_argument(*args, action="store_true", **kwargs)
@@ -38,6 +38,9 @@ class Group:
group.add_argument(
"--no-" + args[0][2:], *args[1:], action="store_false", **kwargs
)
if nullable:
params = {args[0][2:].replace("-", "_"): None}
group.set_defaults(**params)
def prepare(self, args) -> None:
...
@@ -67,7 +70,8 @@ class ApplicationGroup(Group):
name = "Application"
def attach(self):
self.container.add_argument(
group = self.container.add_mutually_exclusive_group()
group.add_argument(
"--factory",
action="store_true",
help=(
@@ -75,7 +79,7 @@ class ApplicationGroup(Group):
"i.e. a () -> <Sanic app> callable"
),
)
self.container.add_argument(
group.add_argument(
"-s",
"--simple",
dest="simple",
@@ -85,6 +89,32 @@ class ApplicationGroup(Group):
"a directory\n(module arg should be a path)"
),
)
group.add_argument(
"--inspect",
dest="inspect",
action="store_true",
help=("Inspect the state of a running instance, human readable"),
)
group.add_argument(
"--inspect-raw",
dest="inspect_raw",
action="store_true",
help=("Inspect the state of a running instance, JSON output"),
)
group.add_argument(
"--trigger-reload",
dest="trigger",
action="store_const",
const="reload",
help=("Trigger worker processes to reload"),
)
group.add_argument(
"--trigger-shutdown",
dest="trigger",
action="store_const",
const="shutdown",
help=("Trigger all processes to shutdown"),
)
class HTTPVersionGroup(Group):
@@ -207,8 +237,22 @@ class WorkerGroup(Group):
action="store_true",
help="Set the number of workers to max allowed",
)
group.add_argument(
"--single-process",
dest="single",
action="store_true",
help="Do not use multiprocessing, run server in a single process",
)
self.container.add_argument(
"--legacy",
action="store_true",
help="Use the legacy server manager",
)
self.add_bool_arguments(
"--access-logs", dest="access_log", help="display access logs"
"--access-logs",
dest="access_log",
help="display access logs",
default=None,
)
@@ -262,6 +306,12 @@ class OutputGroup(Group):
name = "Output"
def attach(self):
self.add_bool_arguments(
"--coffee",
dest="coffee",
default=False,
help="Uhm, coffee?",
)
self.add_bool_arguments(
"--motd",
dest="motd",

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
import sys
from inspect import getmembers, isclass, isdatadescriptor
from os import environ
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, Union
from warnings import filterwarnings
from sanic.constants import LocalCertCreator
from sanic.errorpages import DEFAULT_FORMAT, check_error_format
@@ -13,18 +16,36 @@ from sanic.log import deprecation, error_logger
from sanic.utils import load_module_from_file_location, str_to_bool
if sys.version_info >= (3, 8):
from typing import Literal
FilterWarningType = Union[
Literal["default"],
Literal["error"],
Literal["ignore"],
Literal["always"],
Literal["module"],
Literal["once"],
]
else:
FilterWarningType = str
SANIC_PREFIX = "SANIC_"
DEFAULT_CONFIG = {
"_FALLBACK_ERROR_FORMAT": _default,
"ACCESS_LOG": True,
"ACCESS_LOG": False,
"AUTO_EXTEND": True,
"AUTO_RELOAD": False,
"EVENT_AUTOREGISTER": False,
"DEPRECATION_FILTER": "once",
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
"FORWARDED_SECRET": None,
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
"INSPECTOR": False,
"INSPECTOR_HOST": "localhost",
"INSPECTOR_PORT": 6457,
"KEEP_ALIVE_TIMEOUT": 5, # 5 seconds
"KEEP_ALIVE": True,
"LOCAL_CERT_CREATOR": LocalCertCreator.AUTO,
@@ -69,9 +90,13 @@ class Config(dict, metaclass=DescriptorMeta):
AUTO_EXTEND: bool
AUTO_RELOAD: bool
EVENT_AUTOREGISTER: bool
DEPRECATION_FILTER: FilterWarningType
FORWARDED_FOR_HEADER: str
FORWARDED_SECRET: Optional[str]
GRACEFUL_SHUTDOWN_TIMEOUT: float
INSPECTOR: bool
INSPECTOR_HOST: str
INSPECTOR_PORT: int
KEEP_ALIVE_TIMEOUT: int
KEEP_ALIVE: bool
LOCAL_CERT_CREATOR: Union[str, LocalCertCreator]
@@ -124,22 +149,23 @@ class Config(dict, metaclass=DescriptorMeta):
self.load_environment_vars(SANIC_PREFIX)
self._configure_header_size()
self._configure_warnings()
self._check_error_format()
self._init = True
def __getattr__(self, attr):
def __getattr__(self, attr: Any):
try:
return self[attr]
except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'")
def __setattr__(self, attr, value) -> None:
def __setattr__(self, attr: str, value: Any) -> None:
self.update({attr: value})
def __setitem__(self, attr, value) -> None:
def __setitem__(self, attr: str, value: Any) -> None:
self.update({attr: value})
def update(self, *other, **kwargs) -> None:
def update(self, *other: Any, **kwargs: Any) -> None:
kwargs.update({k: v for item in other for k, v in dict(item).items()})
setters: Dict[str, Any] = {
k: kwargs.pop(k)
@@ -172,6 +198,8 @@ class Config(dict, metaclass=DescriptorMeta):
self.LOCAL_CERT_CREATOR = LocalCertCreator[
self.LOCAL_CERT_CREATOR.upper()
]
elif attr == "DEPRECATION_FILTER":
self._configure_warnings()
@property
def FALLBACK_ERROR_FORMAT(self) -> str:
@@ -199,6 +227,13 @@ class Config(dict, metaclass=DescriptorMeta):
self.REQUEST_MAX_SIZE,
)
def _configure_warnings(self):
filterwarnings(
self.DEPRECATION_FILTER,
category=DeprecationWarning,
module=r"sanic.*",
)
def _check_error_format(self, format: Optional[str] = None):
check_error_format(format or self.FALLBACK_ERROR_FORMAT)
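
Two of the config changes above are easy to miss: ACCESS_LOG defaults to False in this diff, and the new DEPRECATION_FILTER setting feeds warnings.filterwarnings() for sanic.* DeprecationWarnings whenever it is assigned. A minimal sketch of adjusting both at runtime:

    from sanic import Sanic

    app = Sanic("Example")
    app.config.ACCESS_LOG = True              # re-enable access logging
    app.config.DEPRECATION_FILTER = "ignore"  # or "error", "always", "once", ...
    # Assigning DEPRECATION_FILTER re-runs _configure_warnings(), so the new
    # filter takes effect immediately.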

View File

@@ -34,6 +34,15 @@ class LocalCertCreator(str, Enum):
HTTP_METHODS = tuple(HTTPMethod.__members__.values())
SAFE_HTTP_METHODS = (HTTPMethod.GET, HTTPMethod.HEAD, HTTPMethod.OPTIONS)
IDEMPOTENT_HTTP_METHODS = (
HTTPMethod.GET,
HTTPMethod.HEAD,
HTTPMethod.PUT,
HTTPMethod.DELETE,
HTTPMethod.OPTIONS,
)
CACHEABLE_HTTP_METHODS = (HTTPMethod.GET, HTTPMethod.HEAD)
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
DEFAULT_LOCAL_TLS_KEY = "key.pem"
DEFAULT_LOCAL_TLS_CERT = "cert.pem"
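
The new method groupings in sanic/constants.py back the Request helpers added in commit 7827b1b41d (#2516). They can also be used directly; a small sketch:

    from sanic.constants import (
        CACHEABLE_HTTP_METHODS,
        IDEMPOTENT_HTTP_METHODS,
        SAFE_HTTP_METHODS,
        HTTPMethod,
    )

    assert HTTPMethod.GET in SAFE_HTTP_METHODS
    assert HTTPMethod.PUT in IDEMPOTENT_HTTP_METHODS
    assert HTTPMethod.POST not in CACHEABLE_HTTP_METHODS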

View File

@@ -1,8 +1,13 @@
from asyncio import CancelledError
from typing import Any, Dict, Optional, Union
from sanic.helpers import STATUS_CODES
class RequestCancelled(CancelledError):
quiet = True
class SanicException(Exception):
message: str = ""

View File

@@ -1,16 +1,317 @@
from __future__ import annotations
from functools import partial
from inspect import isawaitable
from traceback import format_exc
from typing import Dict, List, Optional, Tuple, Type
from sanic_routing import Route
from sanic.errorpages import BaseRenderer, TextRenderer, exception_response
from sanic.exceptions import (
HeaderNotFound,
InvalidRangeType,
RangeNotSatisfiable,
SanicException,
ServerError,
)
from sanic.log import deprecation, error_logger
from sanic.http.constants import Stage
from sanic.log import deprecation, error_logger, logger
from sanic.models.handler_types import RouteHandler
from sanic.response import text
from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream, text
from sanic.touchup import TouchUpMeta
class RequestHandler:
def __init__(self, func, request_middleware, response_middleware):
self.func = func.func if isinstance(func, RequestHandler) else func
self.request_middleware = request_middleware
self.response_middleware = response_middleware
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
class RequestManager(metaclass=TouchUpMeta):
__touchup__ = (
"cleanup",
"run_request_middleware",
"run_response_middleware",
)
__slots__ = (
"handler",
"request_middleware_run",
"request_middleware",
"request",
"response_middleware_run",
"response_middleware",
)
request: Request
def __init__(self, request: Request):
self.request_middleware_run = False
self.response_middleware_run = False
self.handler = self._noop
self.set_request(request)
@classmethod
def create(cls, request: Request) -> RequestManager:
return cls(request)
def set_request(self, request: Request):
request._manager = self
self.request = request
self.request_middleware = request.app.request_middleware
self.response_middleware = request.app.response_middleware
async def handle(self):
route = self.resolve_route()
if self.handler is None:
await self.error(
ServerError(
(
"'None' was returned while requesting a "
"handler from the router"
)
)
)
return
if (
self.request.stream
and self.request.stream.request_body
and not route.ctx.ignore_body
):
await self.receive_body()
await self.lifecycle(
partial(self.handler, self.request, **self.request.match_info)
)
async def lifecycle(self, handler, raise_exception: bool = False):
response: Optional[BaseHTTPResponse] = None
if not self.request_middleware_run and self.request_middleware:
response = await self.run(
self.run_request_middleware, raise_exception
)
if not response:
# Run response handler
response = await self.run(handler, raise_exception)
if not self.response_middleware_run and self.response_middleware:
response = await self.run(
partial(self.run_response_middleware, response),
raise_exception,
)
await self.cleanup(response)
async def run(
self, operation, raise_exception: bool = False
) -> Optional[BaseHTTPResponse]:
try:
response = operation()
if isawaitable(response):
response = await response
except Exception as e:
if raise_exception:
raise
response = await self.error(e)
return response
async def error(self, exception: Exception):
error_handler = self.request.app.error_handler
if (
self.request.stream is not None
and self.request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
handler = error_handler._lookup(
exception, self.request.name if self.request else None
)
if handler:
logger.warning(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"The response from your custom exception handler "
f"{handler.__name__} will not be sent to the client."
"Exception handlers should only be used to generate the "
"exception responses. If you would like to perform any "
"other action on a raised exception, consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
)
return
try:
await self.lifecycle(
partial(error_handler.response, self.request, exception), True
)
except Exception as e:
if isinstance(e, SanicException):
response = error_handler.default(self.request, e)
elif self.request.app.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
error_logger.exception(e)
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
return response
return None
async def cleanup(self, response: Optional[BaseHTTPResponse]):
if self.request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if self.request.stream is not None:
response = self.request.stream.response
elif response is not None:
self.request.reset_response()
response = await self.request.respond(response) # type: ignore
elif not hasattr(self.handler, "is_websocket"):
response = self.request.stream.response # type: ignore
if isinstance(response, BaseHTTPResponse):
await self.request.app.dispatch(
"http.lifecycle.response",
inline=True,
context={"request": self.request, "response": response},
)
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
await response(self.request) # type: ignore
await response.eof() # type: ignore
await self.request.app.dispatch(
"http.lifecycle.response",
inline=True,
context={"request": self.request, "response": response},
)
else:
if not hasattr(self.handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
async def receive_body(self):
if hasattr(self.handler, "is_stream"):
# Streaming handler: lift the size limit
self.request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await self.request.receive_body()
async def run_request_middleware(self) -> Optional[BaseHTTPResponse]:
self.request._request_middleware_started = True
self.request_middleware_run = True
for middleware in self.request_middleware:
await self.request.app.dispatch(
"http.middleware.before",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
try:
response = await self.run(partial(middleware, self.request))
except Exception:
error_logger.exception(
"Exception occurred in one of request middleware handlers"
)
raise
await self.request.app.dispatch(
"http.middleware.after",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
if response:
return response
return None
async def run_response_middleware(
self, response: BaseHTTPResponse
) -> BaseHTTPResponse:
self.response_middleware_run = True
for middleware in self.response_middleware:
await self.request.app.dispatch(
"http.middleware.before",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
try:
resp = await self.run(
partial(middleware, self.request, response), True
)
except Exception as e:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
)
await self.error(e)
resp = None
await self.request.app.dispatch(
"http.middleware.after",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
if resp:
return resp
return response
def resolve_route(self) -> Route:
# Fetch handler from router
route, handler, kwargs = self.request.app.router.get(
self.request.path,
self.request.method,
self.request.headers.getone("host", None),
)
self.request._match_info = {**kwargs}
self.request.route = route
self.handler = handler
if handler and handler.request_middleware:
self.request_middleware = handler.request_middleware
if handler and handler.response_middleware:
self.response_middleware = handler.response_middleware
return route
@staticmethod
def _noop(_):
...
class ErrorHandler:
@@ -47,6 +348,28 @@ class ErrorHandler:
def _full_lookup(self, exception, route_name: Optional[str] = None):
return self.lookup(exception, route_name)
def _add(
self,
key: Tuple[Type[BaseException], Optional[str]],
handler: RouteHandler,
) -> None:
if key in self.cached_handlers:
exc, name = key
if name is None:
name = "__ALL_ROUTES__"
error_logger.warning(
f"Duplicate exception handler definition on: route={name} "
f"and exception={exc}"
)
deprecation(
"A duplicate exception handler definition was discovered. "
"This may cause unintended consequences. A warning has been "
"issued now, but it will not be allowed starting in v23.3.",
23.3,
)
self.cached_handlers[key] = handler
def add(self, exception, handler, route_names: Optional[List[str]] = None):
"""
Add a new exception handler to an already existing handler object.
@@ -62,9 +385,9 @@ class ErrorHandler:
"""
if route_names:
for route in route_names:
self.cached_handlers[(exception, route)] = handler
self._add((exception, route), handler)
else:
self.cached_handlers[(exception, None)] = handler
self._add((exception, None), handler)
def lookup(self, exception, route_name: Optional[str] = None):
"""

View File

@@ -14,8 +14,8 @@ from sanic.exceptions import (
BadRequest,
ExpectationFailed,
PayloadTooLarge,
RequestCancelled,
ServerError,
ServiceUnavailable,
)
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
@@ -124,7 +124,8 @@ class Http(Stream, metaclass=TouchUpMeta):
self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
await self.request.manager.handle()
# Handler finished, response should've been sent
if self.stage is Stage.HANDLER and not self.upgrade_websocket:
@@ -132,7 +133,7 @@ class Http(Stream, metaclass=TouchUpMeta):
if self.stage is Stage.RESPONSE:
await self.response.send(end_stream=True)
except CancelledError:
except CancelledError as exc:
# Write an appropriate response before exiting
if not self.protocol.transport:
logger.info(
@@ -140,7 +141,11 @@ class Http(Stream, metaclass=TouchUpMeta):
"stopped. Transport is closed."
)
return
e = self.exception or ServiceUnavailable("Cancelled")
e = (
RequestCancelled()
if self.protocol.conn_info.lost
else (self.exception or exc)
)
self.exception = None
self.keep_alive = False
await self.error_response(e)
@@ -246,6 +251,7 @@ class Http(Stream, metaclass=TouchUpMeta):
transport=self.protocol.transport,
app=self.protocol.app,
)
self.protocol.request_handler.create(request)
self.protocol.request_class._current.set(request)
await self.dispatch(
"http.lifecycle.request",
@@ -419,12 +425,11 @@ class Http(Stream, metaclass=TouchUpMeta):
# From request and handler states we can respond, otherwise be silent
if self.stage is Stage.HANDLER:
app = self.protocol.app
if self.request is None:
self.create_empty_request()
self.protocol.request_handler.create(self.request)
await app.handle_exception(self.request, exception)
await self.request.manager.error(exception)
def create_empty_request(self) -> None:
"""

View File

@@ -22,7 +22,7 @@ from sanic.exceptions import PayloadTooLarge, SanicException, ServerError
from sanic.helpers import has_message_body
from sanic.http.constants import Stage
from sanic.http.stream import Stream
from sanic.http.tls.context import CertSelector, CertSimple, SanicSSLContext
from sanic.http.tls.context import CertSelector, SanicSSLContext
from sanic.log import Colors, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.models.server_types import ConnInfo
@@ -378,7 +378,7 @@ def get_config(
app: Sanic, ssl: Union[SanicSSLContext, CertSelector, SSLContext]
):
# TODO:
# - proper selection needed if servince with multiple certs insted of
# - proper selection needed if service with multiple certs instead of
# just taking the first
if isinstance(ssl, CertSelector):
ssl = cast(SanicSSLContext, ssl.sanic_select[0])
@@ -389,8 +389,8 @@ def get_config(
"should be able to use mkcert instead. For more information, see: "
"https://github.com/aiortc/aioquic/issues/295."
)
if not isinstance(ssl, CertSimple):
raise SanicException("SSLContext is not CertSimple")
if not isinstance(ssl, SanicSSLContext):
raise SanicException("SSLContext is not SanicSSLContext")
config = QuicConfiguration(
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],

View File

@@ -240,7 +240,12 @@ class MkcertCreator(CertCreator):
self.cert_path.unlink()
self.tmpdir.rmdir()
return CertSimple(self.cert_path, self.key_path)
context = CertSimple(self.cert_path, self.key_path)
context.sanic["creator"] = "mkcert"
context.sanic["localhost"] = localhost
SanicSSLContext.create_from_ssl_context(context)
return context
class TrustmeCreator(CertCreator):
@@ -259,20 +264,23 @@ class TrustmeCreator(CertCreator):
)
def generate_cert(self, localhost: str) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sanic_context = SanicSSLContext.create_from_ssl_context(context)
sanic_context.sanic = {
context = SanicSSLContext.create_from_ssl_context(
ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
)
context.sanic = {
"cert": self.cert_path.absolute(),
"key": self.key_path.absolute(),
}
ca = trustme.CA()
server_cert = ca.issue_cert(localhost)
server_cert.configure_cert(sanic_context)
server_cert.configure_cert(context)
ca.configure_trust(context)
ca.cert_pem.write_to_path(str(self.cert_path.absolute()))
server_cert.private_key_and_cert_chain_pem.write_to_path(
str(self.key_path.absolute())
)
context.sanic["creator"] = "trustme"
context.sanic["localhost"] = localhost
return context

View File

@@ -64,10 +64,11 @@ Default logging configuration
class Colors(str, Enum): # no cov
END = "\033[0m"
BLUE = "\033[01;34m"
GREEN = "\033[01;32m"
PURPLE = "\033[01;35m"
RED = "\033[01;31m"
BOLD = "\033[1m"
BLUE = "\033[34m"
GREEN = "\033[32m"
PURPLE = "\033[35m"
RED = "\033[31m"
SANIC = "\033[38;2;255;13;104m"
YELLOW = "\033[01;33m"

66
sanic/middleware.py Normal file
View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Sequence, Union
from sanic.models.handler_types import MiddlewareType
class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()
class Middleware:
_counter = count()
__slots__ = ("func", "priority", "location", "definition")
def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation = MiddlewareLocation.REQUEST,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware._counter)
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __repr__(self) -> str:
name = getattr(self.func, "__name__", str(self.func))
return (
f"{self.__class__.__name__}("
f"func=<function {name}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)
@property
def order(self):
return (self.priority, -self.definition)
@classmethod
def convert(
cls,
*middleware_collections: Sequence[Union[Middleware, MiddlewareType]],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)
@classmethod
def reset_count(cls):
cls._counter = count()
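A small sketch of the ordering key defined above: higher priority sorts first, and the registration counter breaks ties in favor of earlier definitions (the function names are placeholders):
from sanic.middleware import Middleware, MiddlewareLocation

async def first(request): ...
async def second(request): ...

low = Middleware(first, MiddlewareLocation.REQUEST, priority=0)
high = Middleware(second, MiddlewareLocation.REQUEST, priority=1)

ordered = sorted([low, high], key=lambda m: m.order, reverse=True)
assert ordered[0] is high  # (1, -n) beats (0, -m) when sorted descending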

View File

@@ -17,9 +17,12 @@ class ListenerEvent(str, Enum):
BEFORE_SERVER_STOP = "server.shutdown.before"
AFTER_SERVER_STOP = "server.shutdown.after"
MAIN_PROCESS_START = auto()
MAIN_PROCESS_READY = auto()
MAIN_PROCESS_STOP = auto()
RELOAD_PROCESS_START = auto()
RELOAD_PROCESS_STOP = auto()
BEFORE_RELOAD_TRIGGER = auto()
AFTER_RELOAD_TRIGGER = auto()
class ListenerMixin(metaclass=SanicMeta):
@@ -98,6 +101,11 @@ class ListenerMixin(metaclass=SanicMeta):
) -> ListenerType[Sanic]:
return self.listener(listener, "main_process_start")
def main_process_ready(
self, listener: ListenerType[Sanic]
) -> ListenerType[Sanic]:
return self.listener(listener, "main_process_ready")
def main_process_stop(
self, listener: ListenerType[Sanic]
) -> ListenerType[Sanic]:
@@ -113,6 +121,16 @@ class ListenerMixin(metaclass=SanicMeta):
) -> ListenerType[Sanic]:
return self.listener(listener, "reload_process_stop")
def before_reload_trigger(
self, listener: ListenerType[Sanic]
) -> ListenerType[Sanic]:
return self.listener(listener, "before_reload_trigger")
def after_reload_trigger(
self, listener: ListenerType[Sanic]
) -> ListenerType[Sanic]:
return self.listener(listener, "after_reload_trigger")
def before_server_start(
self, listener: ListenerType[Sanic]
) -> ListenerType[Sanic]:

View File

@@ -1,11 +1,18 @@
from collections import deque
from functools import partial
from operator import attrgetter
from typing import List
from sanic.base.meta import SanicMeta
from sanic.handlers import RequestHandler
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware
from sanic.router import Router
class MiddlewareMixin(metaclass=SanicMeta):
router: Router
def __init__(self, *args, **kwargs) -> None:
self._future_middleware: List[FutureMiddleware] = []
@@ -13,7 +20,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
raise NotImplementedError # noqa
def middleware(
self, middleware_or_request, attach_to="request", apply=True
self,
middleware_or_request,
attach_to="request",
apply=True,
*,
priority=0
):
"""
Decorate and register middleware to be called before a request
@@ -30,6 +42,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
def register_middleware(middleware, attach_to="request"):
nonlocal apply
location = (
MiddlewareLocation.REQUEST
if attach_to == "request"
else MiddlewareLocation.RESPONSE
)
middleware = Middleware(middleware, location, priority=priority)
future_middleware = FutureMiddleware(middleware, attach_to)
self._future_middleware.append(future_middleware)
if apply:
@@ -46,7 +64,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
register_middleware, attach_to=middleware_or_request
)
def on_request(self, middleware=None):
def on_request(self, middleware=None, *, priority=0):
"""Register a middleware to be called before a request is handled.
This is the same as *@app.middleware('request')*.
@@ -54,11 +72,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
:param: middleware: A callable that takes in request.
"""
if callable(middleware):
return self.middleware(middleware, "request")
return self.middleware(middleware, "request", priority=priority)
else:
return partial(self.middleware, attach_to="request")
return partial(
self.middleware, attach_to="request", priority=priority
)
def on_response(self, middleware=None):
def on_response(self, middleware=None, *, priority=0):
"""Register a middleware to be called after a response is created.
This is the same as *@app.middleware('response')*.
@@ -67,6 +87,61 @@ class MiddlewareMixin(metaclass=SanicMeta):
A callable that takes in a request and its response.
"""
if callable(middleware):
return self.middleware(middleware, "response")
return self.middleware(middleware, "response", priority=priority)
else:
return partial(self.middleware, attach_to="response")
return partial(
self.middleware, attach_to="response", priority=priority
)
def finalize_middleware(self):
for route in self.router.routes:
request_middleware = Middleware.convert(
self.request_middleware,
self.named_request_middleware.get(route.name, deque()),
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.handler = RequestHandler(
route.handler,
deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
),
deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
),
)
request_middleware = Middleware.convert(
self.request_middleware,
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
location=MiddlewareLocation.RESPONSE,
)
self.request_middleware = deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
self.response_middleware = deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
)
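A short sketch of the priority keyword in practice; given the sorting above, the higher-priority middleware should run first even though it is attached later (names are placeholders):
from sanic import Sanic

app = Sanic("Example")

@app.on_request
async def runs_second(request):
    ...

@app.on_request(priority=99)
async def runs_first(request):  # higher priority wins despite later attachment
    ...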

View File

@@ -4,8 +4,7 @@ from functools import partial, wraps
from inspect import getsource, signature
from mimetypes import guess_type
from os import path
from pathlib import PurePath
from re import sub
from pathlib import Path, PurePath
from textwrap import dedent
from time import gmtime, strftime
from typing import (
@@ -27,12 +26,7 @@ from sanic.base.meta import SanicMeta
from sanic.compat import stat_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
from sanic.errorpages import RESPONSE_MAPPING
from sanic.exceptions import (
BadRequest,
FileNotFound,
HeaderNotFound,
RangeNotSatisfiable,
)
from sanic.exceptions import FileNotFound, HeaderNotFound, RangeNotSatisfiable
from sanic.handlers import ContentRangeHandler
from sanic.log import error_logger
from sanic.models.futures import FutureRoute, FutureStatic
@@ -231,7 +225,7 @@ class RouteMixin(metaclass=SanicMeta):
stream: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""A helper method to register class instance or
functions as a handler to the application url
@@ -292,7 +286,7 @@ class RouteMixin(metaclass=SanicMeta):
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **GET** *HTTP* method
@@ -335,7 +329,7 @@ class RouteMixin(metaclass=SanicMeta):
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **POST** *HTTP* method
@@ -378,7 +372,7 @@ class RouteMixin(metaclass=SanicMeta):
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **PUT** *HTTP* method
@@ -421,7 +415,7 @@ class RouteMixin(metaclass=SanicMeta):
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **HEAD** *HTTP* method
@@ -472,7 +466,7 @@ class RouteMixin(metaclass=SanicMeta):
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **OPTIONS** *HTTP* method
@@ -523,7 +517,7 @@ class RouteMixin(metaclass=SanicMeta):
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **PATCH** *HTTP* method
@@ -576,7 +570,7 @@ class RouteMixin(metaclass=SanicMeta):
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
) -> RouteHandler:
"""
Add an API URL under the **DELETE** *HTTP* method
@@ -620,7 +614,7 @@ class RouteMixin(metaclass=SanicMeta):
apply: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
):
"""
Decorate a function to be registered as a websocket route
@@ -664,7 +658,7 @@ class RouteMixin(metaclass=SanicMeta):
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
**ctx_kwargs,
**ctx_kwargs: Any,
):
"""
A helper method to register a function as a websocket route.
@@ -701,18 +695,18 @@ class RouteMixin(metaclass=SanicMeta):
def static(
self,
uri,
uri: str,
file_or_directory: Union[str, bytes, PurePath],
pattern=r"/?.+",
use_modified_since=True,
use_content_range=False,
stream_large_files=False,
name="static",
host=None,
strict_slashes=None,
content_type=None,
apply=True,
resource_type=None,
pattern: str = r"/?.+",
use_modified_since: bool = True,
use_content_range: bool = False,
stream_large_files: bool = False,
name: str = "static",
host: Optional[str] = None,
strict_slashes: Optional[bool] = None,
content_type: Optional[bool] = None,
apply: bool = True,
resource_type: Optional[str] = None,
):
"""
Register a root to serve files from. The input can either be a
@@ -806,32 +800,40 @@ class RouteMixin(metaclass=SanicMeta):
content_type=None,
__file_uri__=None,
):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if __file_uri__ and "../" in __file_uri__:
raise BadRequest("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory
if __file_uri__:
file_path = path.join(
file_or_directory, sub("^[/]*", "", __file_uri__)
)
file_path_raw = Path(unquote(file_or_directory))
root_path = file_path = file_path_raw.resolve()
not_found = FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
# URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={__file_uri__}"
)
raise FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
if __file_uri__:
# Strip all / that in the beginning of the URL to help prevent
# python from herping a derp and treating the uri as an
# absolute path
unquoted_file_uri = unquote(__file_uri__).lstrip("/")
file_path_raw = Path(file_or_directory, unquoted_file_uri)
file_path = file_path_raw.resolve()
if (
file_path < root_path and not file_path_raw.is_symlink()
) or ".." in file_path_raw.parts:
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={__file_uri__}"
)
raise not_found
try:
file_path.relative_to(root_path)
except ValueError:
if not file_path_raw.is_symlink():
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={__file_uri__}"
)
raise not_found
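For reference, a tiny illustration of the pathlib containment check used here (the paths are hypothetical):
from pathlib import Path

root = Path("/srv/static").resolve()
requested = Path("/srv/static", "../secret.txt").resolve()   # -> /srv/secret.txt
try:
    requested.relative_to(root)
except ValueError:
    print("resolved outside the served root -> treated as FileNotFound")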
try:
headers = {}
# Check if the client has been sent this file before
@@ -899,11 +901,7 @@ class RouteMixin(metaclass=SanicMeta):
except RangeNotSatisfiable:
raise
except FileNotFoundError:
raise FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
raise not_found
except Exception:
error_logger.exception(
f"Exception in static request handler: "
@@ -960,6 +958,7 @@ class RouteMixin(metaclass=SanicMeta):
# serve from the folder
if not static.resource_type:
if not path.isfile(file_or_directory):
uri = uri.rstrip("/")
uri += "/<__file_uri__:path>"
elif static.resource_type == "dir":
if path.isfile(file_or_directory):
@@ -967,6 +966,7 @@ class RouteMixin(metaclass=SanicMeta):
"Resource type improperly identified as directory. "
f"'{file_or_directory}'"
)
uri = uri.rstrip("/")
uri += "/<__file_uri__:path>"
elif static.resource_type == "file" and not path.isfile(
file_or_directory

View File

@@ -16,12 +16,15 @@ from asyncio import (
from contextlib import suppress
from functools import partial
from importlib import import_module
from multiprocessing import Manager, Pipe, get_context
from multiprocessing.context import BaseContext
from pathlib import Path
from socket import socket
from ssl import SSLContext
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
@@ -32,7 +35,6 @@ from typing import (
cast,
)
from sanic import reloader_helpers
from sanic.application.logo import get_logo
from sanic.application.motd import MOTD
from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
@@ -41,15 +43,25 @@ from sanic.compat import OS_IS_WINDOWS, is_atty
from sanic.helpers import _default
from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context, process_to_context
from sanic.http.tls.context import SanicSSLContext
from sanic.log import Colors, deprecation, error_logger, logger
from sanic.models.handler_types import ListenerType
from sanic.server import Signal as ServerSignal
from sanic.server import try_use_uvloop
from sanic.server.async_server import AsyncioServer
from sanic.server.events import trigger_events
from sanic.server.legacy import watchdog
from sanic.server.loop import try_windows_loop
from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
from sanic.server.runners import serve, serve_multiple, serve_single
from sanic.server.socket import configure_socket, remove_unix_socket
from sanic.worker.inspector import Inspector
from sanic.worker.loader import AppLoader
from sanic.worker.manager import WorkerManager
from sanic.worker.multiplexer import WorkerMultiplexer
from sanic.worker.reloader import Reloader
from sanic.worker.serve import worker_serve
if TYPE_CHECKING:
@@ -59,20 +71,35 @@ if TYPE_CHECKING:
SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext")
if sys.version_info < (3, 8):
if sys.version_info < (3, 8): # no cov
HTTPVersion = Union[HTTP, int]
else:
else: # no cov
from typing import Literal
HTTPVersion = Union[HTTP, Literal[1], Literal[3]]
class RunnerMixin(metaclass=SanicMeta):
class StartupMixin(metaclass=SanicMeta):
_app_registry: Dict[str, Sanic]
config: Config
listeners: Dict[str, List[ListenerType[Any]]]
state: ApplicationState
websocket_enabled: bool
multiplexer: WorkerMultiplexer
def setup_loop(self):
if not self.asgi:
if self.config.USE_UVLOOP is True or (
self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS
):
try_use_uvloop()
elif OS_IS_WINDOWS:
try_windows_loop()
@property
def m(self) -> WorkerMultiplexer:
"""Interface for interacting with the worker processes"""
return self.multiplexer
def make_coffee(self, *args, **kwargs):
self.state.coffee = True
@@ -103,6 +130,8 @@ class RunnerMixin(metaclass=SanicMeta):
verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None,
auto_tls: bool = False,
single_process: bool = False,
legacy: bool = False,
) -> None:
"""
Run the HTTP Server and listen until keyboard interrupt or term
@@ -163,9 +192,17 @@ class RunnerMixin(metaclass=SanicMeta):
verbosity=verbosity,
motd_display=motd_display,
auto_tls=auto_tls,
single_process=single_process,
legacy=legacy,
)
self.__class__.serve(primary=self) # type: ignore
if single_process:
serve = self.__class__.serve_single
elif legacy:
serve = self.__class__.serve_legacy
else:
serve = self.__class__.serve
serve(primary=self) # type: ignore
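Put differently, the new flags select between three entry points (a sketch; the port and app name are placeholders):
from sanic import Sanic

app = Sanic("Example")

if __name__ == "__main__":
    app.run(port=8000)                          # default path: Sanic.serve (WorkerManager)
    # app.run(port=8000, single_process=True)   # -> Sanic.serve_single
    # app.run(port=8000, legacy=True)           # -> Sanic.serve_legacy (support dropped in 23.3)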
def prepare(
self,
@@ -191,7 +228,10 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False,
verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None,
coffee: bool = False,
auto_tls: bool = False,
single_process: bool = False,
legacy: bool = False,
) -> None:
if version == 3 and self.state.server_info:
raise RuntimeError(
@@ -204,6 +244,9 @@ class RunnerMixin(metaclass=SanicMeta):
debug = True
auto_reload = True
if debug and access_log is None:
access_log = True
self.state.verbosity = verbosity
if not self.state.auto_reload:
self.state.auto_reload = bool(auto_reload)
@@ -211,6 +254,21 @@ class RunnerMixin(metaclass=SanicMeta):
if fast and workers != 1:
raise RuntimeError("You cannot use both fast=True and workers=X")
if single_process and (fast or (workers > 1) or auto_reload):
raise RuntimeError(
"Single process cannot be run with multiple workers "
"or auto-reload"
)
if single_process and legacy:
raise RuntimeError("Cannot run single process and legacy mode")
if register_sys_signals is False and not (single_process or legacy):
raise RuntimeError(
"Cannot run Sanic.serve with register_sys_signals=False. "
"Use either Sanic.serve_single or Sanic.serve_legacy."
)
if motd_display:
self.config.MOTD_DISPLAY.update(motd_display)
@@ -234,12 +292,6 @@ class RunnerMixin(metaclass=SanicMeta):
"#asynchronous-support"
)
if (
self.__class__.should_auto_reload()
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
): # no cov
return
if sock is None:
host, port = self.get_address(host, port, version, auto_tls)
@@ -265,6 +317,9 @@ class RunnerMixin(metaclass=SanicMeta):
except AttributeError: # no cov
workers = os.cpu_count() or 1
if coffee:
self.state.coffee = True
server_settings = self._helper(
host=host,
port=port,
@@ -283,10 +338,10 @@ class RunnerMixin(metaclass=SanicMeta):
ApplicationServerInfo(settings=server_settings)
)
if self.config.USE_UVLOOP is True or (
self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS
):
try_use_uvloop()
# if self.config.USE_UVLOOP is True or (
# self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS
# ):
# try_use_uvloop()
async def create_server(
self,
@@ -395,18 +450,23 @@ class RunnerMixin(metaclass=SanicMeta):
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
)
def stop(self):
def stop(self, terminate: bool = True, unregister: bool = False):
"""
This kills the Sanic
"""
if terminate and hasattr(self, "multiplexer"):
self.multiplexer.terminate()
if self.state.stage is not ServerStage.STOPPED:
self.shutdown_tasks(timeout=0)
self.shutdown_tasks(timeout=0) # type: ignore
for task in all_tasks():
with suppress(AttributeError):
if task.get_name() == "RunServer":
task.cancel()
get_event_loop().stop()
if unregister:
self.__class__.unregister_app(self) # type: ignore
def _helper(
self,
host: Optional[str] = None,
@@ -468,7 +528,11 @@ class RunnerMixin(metaclass=SanicMeta):
self.motd(server_settings=server_settings)
if is_atty() and not self.state.is_debug:
if (
is_atty()
and not self.state.is_debug
and not os.environ.get("SANIC_IGNORE_PRODUCTION_WARNING")
):
error_logger.warning(
f"{Colors.YELLOW}Sanic is running in PRODUCTION mode. "
"Consider using '--debug' or '--dev' while actively "
@@ -497,6 +561,13 @@ class RunnerMixin(metaclass=SanicMeta):
serve_location: str = "",
server_settings: Optional[Dict[str, Any]] = None,
):
if (
os.environ.get("SANIC_WORKER_NAME")
or os.environ.get("SANIC_MOTD_OUTPUT")
or os.environ.get("SANIC_WORKER_PROCESS")
or os.environ.get("SANIC_SERVER_RUNNING")
):
return
if serve_location:
deprecation(
"Specifying a serve_location in the MOTD is deprecated and "
@@ -506,73 +577,83 @@ class RunnerMixin(metaclass=SanicMeta):
else:
serve_location = self.get_server_location(server_settings)
if self.config.MOTD:
mode = [f"{self.state.mode},"]
if self.state.fast:
mode.append("goin' fast")
if self.state.asgi:
mode.append("ASGI")
else:
if self.state.workers == 1:
mode.append("single worker")
else:
mode.append(f"w/ {self.state.workers} workers")
if server_settings:
server = ", ".join(
(
self.state.server,
server_settings["version"].display(), # type: ignore
)
)
else:
server = ""
display = {
"mode": " ".join(mode),
"server": server,
"python": platform.python_version(),
"platform": platform.platform(),
}
extra = {}
if self.config.AUTO_RELOAD:
reload_display = "enabled"
if self.state.reload_dirs:
reload_display += ", ".join(
[
"",
*(
str(path.absolute())
for path in self.state.reload_dirs
),
]
)
display["auto-reload"] = reload_display
packages = []
for package_name in SANIC_PACKAGES:
module_name = package_name.replace("-", "_")
try:
module = import_module(module_name)
packages.append(
f"{package_name}=={module.__version__}" # type: ignore
)
except ImportError:
...
if packages:
display["packages"] = ", ".join(packages)
if self.config.MOTD_DISPLAY:
extra.update(self.config.MOTD_DISPLAY)
logo = get_logo(coffee=self.state.coffee)
display, extra = self.get_motd_data(server_settings)
MOTD.output(logo, serve_location, display, extra)
def get_motd_data(
self, server_settings: Optional[Dict[str, Any]] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
mode = [f"{self.state.mode},"]
if self.state.fast:
mode.append("goin' fast")
if self.state.asgi:
mode.append("ASGI")
else:
if self.state.workers == 1:
mode.append("single worker")
else:
mode.append(f"w/ {self.state.workers} workers")
if server_settings:
server = ", ".join(
(
self.state.server,
server_settings["version"].display(), # type: ignore
)
)
else:
server = "ASGI" if self.asgi else "unknown" # type: ignore
display = {
"mode": " ".join(mode),
"server": server,
"python": platform.python_version(),
"platform": platform.platform(),
}
extra = {}
if self.config.AUTO_RELOAD:
reload_display = "enabled"
if self.state.reload_dirs:
reload_display += ", ".join(
[
"",
*(
str(path.absolute())
for path in self.state.reload_dirs
),
]
)
display["auto-reload"] = reload_display
packages = []
for package_name in SANIC_PACKAGES:
module_name = package_name.replace("-", "_")
try:
module = import_module(module_name)
packages.append(
f"{package_name}=={module.__version__}" # type: ignore
)
except ImportError: # no cov
...
if packages:
display["packages"] = ", ".join(packages)
if self.config.MOTD_DISPLAY:
extra.update(self.config.MOTD_DISPLAY)
return display, extra
@property
def serve_location(self) -> str:
server_settings = self.state.server_info[0].settings
return self.get_server_location(server_settings)
try:
server_settings = self.state.server_info[0].settings
return self.get_server_location(server_settings)
except IndexError:
location = "ASGI" if self.asgi else "unknown" # type: ignore
return f"http://<{location}>"
@staticmethod
def get_server_location(
@@ -583,24 +664,20 @@ class RunnerMixin(metaclass=SanicMeta):
if not server_settings:
return serve_location
if server_settings["ssl"] is not None:
host = server_settings["host"]
port = server_settings["port"]
if server_settings.get("ssl") is not None:
proto = "https"
if server_settings["unix"]:
if server_settings.get("unix"):
serve_location = f'{server_settings["unix"]} {proto}://...'
elif server_settings["sock"]:
serve_location = (
f'{server_settings["sock"].getsockname()} {proto}://...'
)
elif server_settings["host"] and server_settings["port"]:
elif server_settings.get("sock"):
host, port, *_ = server_settings["sock"].getsockname()
if not serve_location and host and port:
# colon(:) is legal for a host only in an ipv6 address
display_host = (
f'[{server_settings["host"]}]'
if ":" in server_settings["host"]
else server_settings["host"]
)
serve_location = (
f'{proto}://{display_host}:{server_settings["port"]}'
)
display_host = f"[{host}]" if ":" in host else host
serve_location = f"{proto}://{display_host}:{port}"
return serve_location
@@ -620,7 +697,252 @@ class RunnerMixin(metaclass=SanicMeta):
return any(app.state.auto_reload for app in cls._app_registry.values())
@classmethod
def serve(cls, primary: Optional[Sanic] = None) -> None:
def _get_context(cls) -> BaseContext:
method = (
"spawn"
if "linux" not in sys.platform or cls.should_auto_reload()
else "fork"
)
return get_context(method)
@classmethod
def serve(
cls,
primary: Optional[Sanic] = None,
*,
app_loader: Optional[AppLoader] = None,
factory: Optional[Callable[[], Sanic]] = None,
) -> None:
os.environ["SANIC_MOTD_OUTPUT"] = "true"
apps = list(cls._app_registry.values())
if factory:
primary = factory()
else:
if not primary:
if app_loader:
primary = app_loader.load()
if not primary:
try:
primary = apps[0]
except IndexError:
raise RuntimeError(
"Did not find any applications."
) from None
# This exists primarily for unit testing
if not primary.state.server_info: # no cov
for app in apps:
app.state.server_info.clear()
return
try:
primary_server_info = primary.state.server_info[0]
except IndexError:
raise RuntimeError(
f"No server information found for {primary.name}. Perhaps you "
"need to run app.prepare(...)?\n"
"See ____ for more information."
) from None
socks = []
sync_manager = Manager()
try:
main_start = primary_server_info.settings.pop("main_start", None)
main_stop = primary_server_info.settings.pop("main_stop", None)
app = primary_server_info.settings.pop("app")
app.setup_loop()
loop = new_event_loop()
trigger_events(main_start, loop, primary)
socks = [
sock
for sock in [
configure_socket(server_info.settings)
for app in apps
for server_info in app.state.server_info
]
if sock
]
primary_server_info.settings["run_multiple"] = True
monitor_sub, monitor_pub = Pipe(True)
worker_state: Dict[str, Any] = sync_manager.dict()
kwargs: Dict[str, Any] = {
**primary_server_info.settings,
"monitor_publisher": monitor_pub,
"worker_state": worker_state,
}
if not app_loader:
if factory:
app_loader = AppLoader(factory=factory)
else:
app_loader = AppLoader(
factory=partial(cls.get_app, app.name) # type: ignore
)
kwargs["app_name"] = app.name
kwargs["app_loader"] = app_loader
kwargs["server_info"] = {}
kwargs["passthru"] = {
"auto_reload": app.auto_reload,
"state": {
"verbosity": app.state.verbosity,
"mode": app.state.mode,
},
"config": {
"ACCESS_LOG": app.config.ACCESS_LOG,
"NOISY_EXCEPTIONS": app.config.NOISY_EXCEPTIONS,
},
"shared_ctx": app.shared_ctx.__dict__,
}
for app in apps:
kwargs["server_info"][app.name] = []
for server_info in app.state.server_info:
server_info.settings = {
k: v
for k, v in server_info.settings.items()
if k not in ("main_start", "main_stop", "app", "ssl")
}
kwargs["server_info"][app.name].append(server_info)
ssl = kwargs.get("ssl")
if isinstance(ssl, SanicSSLContext):
kwargs["ssl"] = kwargs["ssl"].sanic
manager = WorkerManager(
primary.state.workers,
worker_serve,
kwargs,
cls._get_context(),
(monitor_pub, monitor_sub),
worker_state,
)
if cls.should_auto_reload():
reload_dirs: Set[Path] = primary.state.reload_dirs.union(
*(app.state.reload_dirs for app in apps)
)
reloader = Reloader(monitor_pub, 1.0, reload_dirs, app_loader)
manager.manage("Reloader", reloader, {}, transient=False)
inspector = None
if primary.config.INSPECTOR:
display, extra = primary.get_motd_data()
packages = [
pkg.strip() for pkg in display["packages"].split(",")
]
module = import_module("sanic")
sanic_version = f"sanic=={module.__version__}" # type: ignore
app_info = {
**display,
"packages": [sanic_version, *packages],
"extra": extra,
}
inspector = Inspector(
monitor_pub,
app_info,
worker_state,
primary.config.INSPECTOR_HOST,
primary.config.INSPECTOR_PORT,
)
manager.manage("Inspector", inspector, {}, transient=False)
primary._inspector = inspector
primary._manager = manager
ready = primary.listeners["main_process_ready"]
trigger_events(ready, loop, primary)
manager.run()
except BaseException:
kwargs = primary_server_info.settings
error_logger.exception(
"Experienced exception while trying to serve"
)
raise
finally:
logger.info("Server Stopped")
for app in apps:
app.state.server_info.clear()
app.router.reset()
app.signal_router.reset()
sync_manager.shutdown()
for sock in socks:
sock.close()
socks = []
trigger_events(main_stop, loop, primary)
loop.close()
cls._cleanup_env_vars()
cls._cleanup_apps()
unix = kwargs.get("unix")
if unix:
remove_unix_socket(unix)
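A sketch of the programmatic entry point the new serve signature allows, assuming AppLoader.load() behaves as in the full loader module (names and port are illustrative):
from functools import partial

from sanic import Sanic
from sanic.worker.loader import AppLoader

def create_app() -> Sanic:
    return Sanic("Example")

if __name__ == "__main__":
    loader = AppLoader(factory=partial(create_app))
    app = loader.load()
    app.prepare(port=9999)
    Sanic.serve(primary=app, app_loader=loader)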
@classmethod
def serve_single(cls, primary: Optional[Sanic] = None) -> None:
os.environ["SANIC_MOTD_OUTPUT"] = "true"
apps = list(cls._app_registry.values())
if not primary:
try:
primary = apps[0]
except IndexError:
raise RuntimeError("Did not find any applications.")
# This exists primarily for unit testing
if not primary.state.server_info: # no cov
for app in apps:
app.state.server_info.clear()
return
primary_server_info = primary.state.server_info[0]
primary.before_server_start(partial(primary._start_servers, apps=apps))
kwargs = {
k: v
for k, v in primary_server_info.settings.items()
if k
not in (
"main_start",
"main_stop",
"app",
)
}
kwargs["app_name"] = primary.name
kwargs["app_loader"] = None
sock = configure_socket(kwargs)
kwargs["server_info"] = {}
kwargs["server_info"][primary.name] = []
for server_info in primary.state.server_info:
server_info.settings = {
k: v
for k, v in server_info.settings.items()
if k not in ("main_start", "main_stop", "app")
}
kwargs["server_info"][primary.name].append(server_info)
try:
worker_serve(monitor_publisher=None, **kwargs)
except BaseException:
error_logger.exception(
"Experienced exception while trying to serve"
)
raise
finally:
logger.info("Server Stopped")
for app in apps:
app.state.server_info.clear()
app.router.reset()
app.signal_router.reset()
if sock:
sock.close()
cls._cleanup_env_vars()
cls._cleanup_apps()
@classmethod
def serve_legacy(cls, primary: Optional[Sanic] = None) -> None:
apps = list(cls._app_registry.values())
if not primary:
@@ -641,7 +963,7 @@ class RunnerMixin(metaclass=SanicMeta):
reload_dirs: Set[Path] = primary.state.reload_dirs.union(
*(app.state.reload_dirs for app in apps)
)
reloader_helpers.watchdog(1.0, reload_dirs)
watchdog(1.0, reload_dirs)
trigger_events(reloader_stop, loop, primary)
return
@@ -654,11 +976,17 @@ class RunnerMixin(metaclass=SanicMeta):
primary_server_info = primary.state.server_info[0]
primary.before_server_start(partial(primary._start_servers, apps=apps))
deprecation(
f"{Colors.YELLOW}Running {Colors.SANIC}Sanic {Colors.YELLOW}w/ "
f"LEGACY manager.{Colors.END} Support for will be dropped in "
"version 23.3.",
23.3,
)
try:
primary_server_info.stage = ServerStage.SERVING
if primary.state.workers > 1 and os.name != "posix": # no cov
logger.warn(
logger.warning(
f"Multiprocessing is currently not supported on {os.name},"
" using workers=1 instead"
)
@@ -679,10 +1007,9 @@ class RunnerMixin(metaclass=SanicMeta):
finally:
primary_server_info.stage = ServerStage.STOPPED
logger.info("Server Stopped")
for app in apps:
app.state.server_info.clear()
app.router.reset()
app.signal_router.reset()
cls._cleanup_env_vars()
cls._cleanup_apps()
async def _start_servers(
self,
@@ -720,7 +1047,7 @@ class RunnerMixin(metaclass=SanicMeta):
*server_info.settings.pop("main_start", []),
*server_info.settings.pop("main_stop", []),
]
if handlers:
if handlers: # no cov
error_logger.warning(
f"Sanic found {len(handlers)} listener(s) on "
"secondary applications attached to the main "
@@ -733,12 +1060,15 @@ class RunnerMixin(metaclass=SanicMeta):
if not server_info.settings["loop"]:
server_info.settings["loop"] = get_running_loop()
serve_args: Dict[str, Any] = {
**server_info.settings,
"run_async": True,
"reuse_port": bool(primary.state.workers - 1),
}
if "app" not in serve_args:
serve_args["app"] = app
try:
server_info.server = await serve(
**server_info.settings,
run_async=True,
reuse_port=bool(primary.state.workers - 1),
)
server_info.server = await serve(**serve_args)
except OSError as e: # no cov
first_message = (
"An OSError was detected on startup. "
@@ -764,9 +1094,9 @@ class RunnerMixin(metaclass=SanicMeta):
async def _run_server(
self,
app: RunnerMixin,
app: StartupMixin,
server_info: ApplicationServerInfo,
) -> None:
) -> None: # no cov
try:
# We should never get to this point without a server
@@ -790,3 +1120,26 @@ class RunnerMixin(metaclass=SanicMeta):
finally:
server_info.stage = ServerStage.STOPPED
server_info.server = None
@staticmethod
def _cleanup_env_vars():
variables = (
"SANIC_RELOADER_PROCESS",
"SANIC_IGNORE_PRODUCTION_WARNING",
"SANIC_WORKER_NAME",
"SANIC_MOTD_OUTPUT",
"SANIC_WORKER_PROCESS",
"SANIC_SERVER_RUNNING",
)
for var in variables:
try:
del os.environ[var]
except KeyError:
...
@classmethod
def _cleanup_apps(cls):
for app in cls._app_registry.values():
app.state.server_info.clear()
app.router.reset()
app.signal_router.reset()

View File

@@ -21,6 +21,7 @@ class ConnInfo:
"client",
"client_ip",
"ctx",
"lost",
"peername",
"server_port",
"server",
@@ -33,6 +34,7 @@ class ConnInfo:
def __init__(self, transport: TransportProtocol, unix=None):
self.ctx = SimpleNamespace()
self.lost = False
self.peername = None
self.server = self.client = ""
self.server_port = self.client_port = 0

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from contextvars import ContextVar
from functools import partial
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
@@ -23,6 +24,7 @@ from sanic.models.http_types import Credentials
if TYPE_CHECKING:
from sanic.handlers import RequestManager
from sanic.server import ConnInfo
from sanic.app import Sanic
@@ -37,8 +39,13 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url
from httptools.parser.errors import HttpParserInvalidURLError
from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.compat import Header
from sanic.constants import (
CACHEABLE_HTTP_METHODS,
DEFAULT_HTTP_CONTENT_TYPE,
IDEMPOTENT_HTTP_METHODS,
SAFE_HTTP_METHODS,
)
from sanic.exceptions import BadRequest, BadURL, ServerError
from sanic.headers import (
AcceptContainer,
@@ -51,7 +58,7 @@ from sanic.headers import (
parse_xforwarded,
)
from sanic.http import Stage
from sanic.log import error_logger, logger
from sanic.log import deprecation, error_logger, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse
@@ -94,10 +101,12 @@ class Request:
"_cookies",
"_id",
"_ip",
"_manager",
"_parsed_url",
"_port",
"_protocol",
"_remote_addr",
"_request_middleware_started",
"_scheme",
"_socket",
"_stream_id",
@@ -121,7 +130,6 @@ class Request:
"parsed_token",
"raw_url",
"responded",
"request_middleware_started",
"route",
"stream",
"transport",
@@ -173,10 +181,11 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list)
self.request_middleware_started = False
self._request_middleware_started = False
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
self._manager: Optional[RequestManager] = None
self._cookies: Optional[Dict[str, str]] = None
self._match_info: Dict[str, Any] = {}
self._protocol = None
@@ -188,7 +197,7 @@ class Request:
@classmethod
def get_current(cls) -> Request:
"""
Retrieve the currrent request object
Retrieve the current request object
This implements `Context Variables
<https://docs.python.org/3/library/contextvars.html>`_
@@ -214,6 +223,16 @@ class Request:
def generate_id(*_):
return uuid.uuid4()
@property
def request_middleware_started(self):
deprecation(
"Request.request_middleware_started has been deprecated and will"
"be removed. You should set a flag on the request context using"
"either middleware or signals if you need this feature.",
22.3,
)
return self._request_middleware_started
@property
def stream_id(self):
"""
@@ -228,6 +247,10 @@ class Request:
)
return self._stream_id
@property
def manager(self):
return self._manager
def reset_response(self):
try:
if (
@@ -318,15 +341,13 @@ class Request:
if isawaitable(response):
response = await response # type: ignore
# Run response middleware
try:
response = await self.app._run_response_middleware(
self, response, request_name=self.name
)
except CancelledErrors:
raise
except Exception:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
if (
self._manager
and not self._manager.response_middleware_run
and self._manager.response_middleware
):
response = await self._manager.run(
partial(self._manager.run_response_middleware, response)
)
self.responded = True
return response
@@ -975,6 +996,33 @@ class Request:
return self.transport.scope
@property
def is_safe(self) -> bool:
"""
:return: Whether the HTTP method is safe.
See https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
:rtype: bool
"""
return self.method in SAFE_HTTP_METHODS
@property
def is_idempotent(self) -> bool:
"""
:return: Whether the HTTP method is idempotent.
See https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
:rtype: bool
"""
return self.method in IDEMPOTENT_HTTP_METHODS
@property
def is_cacheable(self) -> bool:
"""
:return: Whether the HTTP method is cacheable.
See https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
:rtype: bool
"""
return self.method in CACHEABLE_HTTP_METHODS
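A small illustrative helper built on the new properties, per the constants referenced above (the function itself is hypothetical, not part of this changeset):
from sanic import Request

def classify(request: Request) -> str:
    # e.g. GET -> "safe idempotent cacheable"; POST -> ""
    flags = []
    if request.is_safe:
        flags.append("safe")
    if request.is_idempotent:
        flags.append("idempotent")
    if request.is_cacheable:
        flags.append("cacheable")
    return " ".join(flags)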
class File(NamedTuple):
"""

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from email.utils import formatdate
from datetime import datetime, timezone
from email.utils import formatdate, parsedate_to_datetime
from functools import partial
from mimetypes import guess_type
from os import path
@@ -33,6 +33,7 @@ from sanic.helpers import (
remove_entity_headers,
)
from sanic.http import Http
from sanic.log import logger
from sanic.models.protocol_types import HTMLProtocol, Range
@@ -210,7 +211,7 @@ class HTTPResponse(BaseHTTPResponse):
def empty(
status=204, headers: Optional[Dict[str, str]] = None
status: int = 204, headers: Optional[Dict[str, str]] = None
) -> HTTPResponse:
"""
Returns an empty response to the client.
@@ -227,7 +228,7 @@ def json(
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
dumps: Optional[Callable[..., str]] = None,
**kwargs,
**kwargs: Any,
) -> HTTPResponse:
"""
Returns response object with body in json format.
@@ -319,9 +320,34 @@ def html(
)
async def validate_file(
request_headers: Header, last_modified: Union[datetime, float, int]
):
try:
if_modified_since = request_headers.getone("If-Modified-Since")
except KeyError:
return
try:
if_modified_since = parsedate_to_datetime(if_modified_since)
except (TypeError, ValueError):
logger.warning(
"Ignorning invalid If-Modified-Since header received: " "'%s'",
if_modified_since,
)
return
if not isinstance(last_modified, datetime):
last_modified = datetime.fromtimestamp(
float(last_modified), tz=timezone.utc
).replace(microsecond=0)
if last_modified <= if_modified_since:
return HTTPResponse(status=304)
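A sketch of the conditional flow this enables in a route handler (the route and file path are illustrative):
from sanic import Sanic
from sanic.response import file

app = Sanic("Example")

@app.get("/logo.png")
async def logo(request):
    # When request_headers is passed, a matching If-Modified-Since yields a 304.
    return await file("/srv/static/logo.png", request_headers=request.headers)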
async def file(
location: Union[str, PurePath],
status: int = 200,
request_headers: Optional[Header] = None,
validate_when_requested: bool = True,
mime_type: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
filename: Optional[str] = None,
@@ -331,7 +357,12 @@ async def file(
_range: Optional[Range] = None,
) -> HTTPResponse:
"""Return a response object with file data.
:param status: HTTP response code. Won't enforce the passed in
status if only a part of the content will be sent (206)
or file is being validated (304).
:param request_headers: The request headers.
:param validate_when_requested: If True, will validate the
file when requested.
:param location: Location of file on system.
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
@@ -341,11 +372,6 @@ async def file(
:param no_store: Any cache should not store this response.
:param _range:
"""
headers = headers or {}
if filename:
headers.setdefault(
"Content-Disposition", f'attachment; filename="{filename}"'
)
if isinstance(last_modified, datetime):
last_modified = last_modified.replace(microsecond=0).timestamp()
@@ -353,9 +379,24 @@ async def file(
stat = await stat_async(location)
last_modified = stat.st_mtime
if (
validate_when_requested
and request_headers is not None
and last_modified
):
response = await validate_file(request_headers, last_modified)
if response:
return response
headers = headers or {}
if last_modified:
headers.setdefault(
"last-modified", formatdate(last_modified, usegmt=True)
"Last-Modified", formatdate(last_modified, usegmt=True)
)
if filename:
headers.setdefault(
"Content-Disposition", f'attachment; filename="{filename}"'
)
if no_store:

View File

@@ -13,6 +13,7 @@ from sanic_routing.route import Route
from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format
from sanic.exceptions import MethodNotAllowed, NotFound, SanicException
from sanic.handlers import RequestHandler
from sanic.models.handler_types import RouteHandler
@@ -31,9 +32,11 @@ class Router(BaseRouter):
def _get(
self, path: str, method: str, host: Optional[str]
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
) -> Tuple[Route, RequestHandler, Dict[str, Any]]:
try:
return self.resolve(
# We know this will always be RequestHandler, so we can ignore
# typing issue here
return self.resolve( # type: ignore
path=path,
method=method,
extra={"host": host} if host else None,
@@ -50,7 +53,7 @@ class Router(BaseRouter):
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
def get( # type: ignore
self, path: str, method: str, host: Optional[str]
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
) -> Tuple[Route, RequestHandler, Dict[str, Any]]:
"""
Retrieve a `Route` object containing the details about how to handle
a response for a given request
@@ -59,7 +62,7 @@ class Router(BaseRouter):
:type request: Request
:return: details needed for handling the request and returning the
correct response
:rtype: Tuple[ Route, RouteHandler, Dict[str, Any]]
:rtype: Tuple[ Route, RequestHandler, Dict[str, Any]]
"""
return self._get(path, method, host)
@@ -114,7 +117,7 @@ class Router(BaseRouter):
params = dict(
path=uri,
handler=handler,
handler=RequestHandler(handler, [], []),
methods=frozenset(map(str, methods)) if methods else None,
name=name,
strict=strict_slashes,

View File

@@ -9,7 +9,6 @@ from time import sleep
def _iter_module_files():
"""This iterates over all relevant Python files.
It goes through all
loaded files from modules, all files in folders of already loaded modules
as well as all files reachable through a package.
@@ -52,7 +51,7 @@ def restart_with_reloader(changed=None):
this one.
"""
reloaded = ",".join(changed) if changed else ""
return subprocess.Popen(
return subprocess.Popen( # nosec B603
_get_args_for_reloading(),
env={
**os.environ,
@@ -79,7 +78,6 @@ def _check_file(filename, mtimes):
def watchdog(sleep_interval, reload_dirs):
"""Watch project files, restart worker process if a change happened.
:param sleep_interval: interval in seconds.
:return: Nothing
"""

View File

@@ -1,4 +1,5 @@
import asyncio
import sys
from distutils.util import strtobool
from os import getenv
@@ -47,3 +48,19 @@ def try_use_uvloop() -> None:
if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy):
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
def try_windows_loop():
if not OS_IS_WINDOWS:
error_logger.warning(
"You are trying to use an event loop policy that is not "
"compatible with your system. You can simply let Sanic handle "
"selecting the best loop for you. Sanic will now continue to run "
"using the default event loop."
)
return
if sys.version_info >= (3, 8) and not isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsSelectorEventLoopPolicy
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

View File

@@ -2,13 +2,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from sanic.exceptions import RequestCancelled
if TYPE_CHECKING:
from sanic.app import Sanic
import asyncio
from asyncio import CancelledError
from asyncio.transports import Transport
from time import monotonic as current_time
@@ -69,7 +70,7 @@ class SanicProtocol(asyncio.Protocol):
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
raise RequestCancelled
self.transport.write(data)
self._time = current_time()
@@ -120,6 +121,7 @@ class SanicProtocol(asyncio.Protocol):
try:
self.connections.discard(self)
self.resume_writing()
self.conn_info.lost = True
if self._task:
self._task.cancel()
except BaseException:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from sanic.handlers import RequestManager
from sanic.http.constants import HTTP
from sanic.http.http3 import Http3
from sanic.touchup.meta import TouchUpMeta
@@ -15,7 +16,11 @@ import sys
from asyncio import CancelledError
from time import monotonic as current_time
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.exceptions import (
RequestCancelled,
RequestTimeout,
ServiceUnavailable,
)
from sanic.http import Http, Stage
from sanic.log import Colors, error_logger, logger
from sanic.models.server_types import ConnInfo
@@ -53,7 +58,7 @@ class HttpProtocolMixin:
def _setup(self):
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.request_handler = RequestManager
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
@@ -225,7 +230,7 @@ class HttpProtocol(HttpProtocolMixin, SanicProtocol, metaclass=TouchUpMeta):
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
raise RequestCancelled
await self.app.dispatch(
"http.lifecycle.send",
inline=True,
@@ -265,7 +270,6 @@ class HttpProtocol(HttpProtocolMixin, SanicProtocol, metaclass=TouchUpMeta):
error_logger.exception("protocol.connect_made")
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:

View File

@@ -129,7 +129,7 @@ def _setup_system_signals(
run_multiple: bool,
register_sys_signals: bool,
loop: asyncio.AbstractEventLoop,
) -> None:
) -> None: # no cov
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
@@ -141,7 +141,9 @@ def _setup_system_signals(
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.add_signal_handler(
_signal, partial(app.stop, terminate=False)
)
def _run_server_forever(loop, before_stop, after_stop, cleanup, unix):
@@ -161,6 +163,7 @@ def _run_server_forever(loop, before_stop, after_stop, cleanup, unix):
loop.run_until_complete(after_stop())
remove_unix_socket(unix)
loop.close()
def _serve_http_1(
@@ -197,8 +200,12 @@ def _serve_http_1(
asyncio_server_kwargs = (
asyncio_server_kwargs if asyncio_server_kwargs else {}
)
if OS_IS_WINDOWS:
pid = os.getpid()
sock = sock.share(pid)
sock = socket.fromshare(sock)
# UNIX sockets are always bound by us (to preserve semantics between modes)
if unix:
elif unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_coroutine = loop.create_server(
server,

View File

@@ -6,7 +6,10 @@ import socket
import stat
from ipaddress import ip_address
from typing import Optional
from typing import Any, Dict, Optional
from sanic.exceptions import ServerError
from sanic.http.constants import HTTP
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
@@ -16,6 +19,8 @@ def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
location = (host, port)
# socket.share, socket.fromshare
try: # IP address: family must be specified for IPv6 at least
ip = ip_address(host)
host = str(ip)
@@ -25,8 +30,9 @@ def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
except ValueError: # Hostname, may become AF_INET or AF_INET6
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.bind(location)
sock.listen(backlog)
sock.set_inheritable(True)
return sock
@@ -36,7 +42,7 @@ def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
"""Open or atomically replace existing socket with zero downtime."""
# Sanitise and pre-verify socket path
path = os.path.abspath(path)
folder = os.path.dirname(path)
@@ -85,3 +91,37 @@ def remove_unix_socket(path: Optional[str]) -> None:
os.unlink(path)
except FileNotFoundError:
pass
def configure_socket(
server_settings: Dict[str, Any]
) -> Optional[socket.SocketType]:
# Create a listening socket or use the one in settings
if server_settings.get("version") is HTTP.VERSION_3:
return None
sock = server_settings.get("sock")
unix = server_settings["unix"]
backlog = server_settings["backlog"]
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_settings["unix"] = unix
if sock is None:
try:
sock = bind_socket(
server_settings["host"],
server_settings["port"],
backlog=backlog,
)
except OSError as e: # no cov
raise ServerError(
f"Sanic server could not start: {e}.\n"
"This may have happened if you are running Sanic in the "
"global scope and not inside of a "
'`if __name__ == "__main__"` block. See more information: '
"____."
) from e
sock.set_inheritable(True)
server_settings["sock"] = sock
server_settings["host"] = None
server_settings["port"] = None
return sock

View File

@@ -252,7 +252,7 @@ class WebsocketImplProtocol:
def _force_disconnect(self) -> bool:
"""
Internal methdod used by end_connection and fail_connection
Internal method used by end_connection and fail_connection
only when the graceful auto-closer cannot be used
"""
if self.auto_closer_task and not self.auto_closer_task.done():

View File

@@ -30,6 +30,8 @@ class Event(Enum):
HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response"
HTTP_ROUTING_AFTER = "http.routing.after"
HTTP_ROUTING_BEFORE = "http.routing.before"
HTTP_HANDLER_AFTER = "http.handler.after"
HTTP_HANDLER_BEFORE = "http.handler.before"
HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
@@ -53,6 +55,8 @@ RESERVED_NAMESPACES = {
Event.HTTP_LIFECYCLE_RESPONSE.value,
Event.HTTP_ROUTING_AFTER.value,
Event.HTTP_ROUTING_BEFORE.value,
Event.HTTP_HANDLER_AFTER.value,
Event.HTTP_HANDLER_BEFORE.value,
Event.HTTP_LIFECYCLE_SEND.value,
Event.HTTP_MIDDLEWARE_AFTER.value,
Event.HTTP_MIDDLEWARE_BEFORE.value,
@@ -150,13 +154,13 @@ class SignalRouter(BaseRouter):
try:
for signal in signals:
params.pop("__trigger__", None)
requirements = getattr(
signal.handler, "__requirements__", None
)
if (
(condition is None and signal.ctx.exclusive is False)
or (
condition is None
and not signal.handler.__requirements__
)
or (condition == signal.handler.__requirements__)
or (condition is None and not requirements)
or (condition == requirements)
) and (signal.ctx.trigger or event == signal.ctx.definition):
maybe_coroutine = signal.handler(**params)
if isawaitable(maybe_coroutine):
@@ -187,7 +191,7 @@ class SignalRouter(BaseRouter):
fail_not_found=fail_not_found and inline,
reverse=reverse,
)
logger.debug(f"Dispatching signal: {event}")
logger.debug(f"Dispatching signal: {event}", extra={"verbosity": 1})
if inline:
return await dispatch
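The two handler events added above can presumably be subscribed to like any other signal; the handler parameters are left open here, since the dispatch context is not shown in this hunk (names are placeholders):
from sanic import Sanic

app = Sanic("Example")

@app.signal("http.handler.before")
async def before_handler(**context):
    ...

@app.signal("http.handler.after")
async def after_handler(**context):
    ...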

55
sanic/types/shared_ctx.py Normal file
View File

@@ -0,0 +1,55 @@
import os
from types import SimpleNamespace
from typing import Any, Iterable
from sanic.log import Colors, error_logger
class SharedContext(SimpleNamespace):
SAFE = ("_lock",)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._lock = False
def __setattr__(self, name: str, value: Any) -> None:
if self.is_locked:
raise RuntimeError(
f"Cannot set {name} on locked SharedContext object"
)
if not os.environ.get("SANIC_WORKER_NAME"):
to_check: Iterable[Any]
if not isinstance(value, (tuple, frozenset)):
to_check = [value]
else:
to_check = value
for item in to_check:
self._check(name, item)
super().__setattr__(name, value)
def _check(self, name: str, value: Any) -> None:
if name in self.SAFE:
return
try:
module = value.__module__
except AttributeError:
module = ""
if not any(
module.startswith(prefix)
for prefix in ("multiprocessing", "ctypes")
):
error_logger.warning(
f"{Colors.YELLOW}Unsafe object {Colors.PURPLE}{name} "
f"{Colors.YELLOW}with type {Colors.PURPLE}{type(value)} "
f"{Colors.YELLOW}was added to shared_ctx. It may not "
"not function as intended. Consider using the regular "
f"ctx. For more information, please see ____.{Colors.END}"
)
@property
def is_locked(self) -> bool:
return getattr(self, "_lock", False)
def lock(self) -> None:
self._lock = True
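A sketch of the intended usage: objects placed on the shared context in the main process should be multiprocessing-backed, or the warning above fires (names are illustrative):
from multiprocessing import Queue

from sanic import Sanic

app = Sanic("Example")

@app.main_process_start
async def attach_queue(app, _):
    # multiprocessing-backed objects pass the module check above
    app.shared_ctx.queue = Queue()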

0
sanic/worker/__init__.py Normal file
View File

141
sanic/worker/inspector.py Normal file
View File

@@ -0,0 +1,141 @@
import sys
from datetime import datetime
from multiprocessing.connection import Connection
from signal import SIGINT, SIGTERM
from signal import signal as signal_func
from socket import AF_INET, SOCK_STREAM, socket, timeout
from textwrap import indent
from typing import Any, Dict
from sanic.application.logo import get_logo
from sanic.application.motd import MOTDTTY
from sanic.log import Colors, error_logger, logger
from sanic.server.socket import configure_socket
try: # no cov
from ujson import dumps, loads
except ModuleNotFoundError: # no cov
from json import dumps, loads # type: ignore
class Inspector:
def __init__(
self,
publisher: Connection,
app_info: Dict[str, Any],
worker_state: Dict[str, Any],
host: str,
port: int,
):
self._publisher = publisher
self.run = True
self.app_info = app_info
self.worker_state = worker_state
self.host = host
self.port = port
def __call__(self) -> None:
sock = configure_socket(
{"host": self.host, "port": self.port, "unix": None, "backlog": 1}
)
assert sock
signal_func(SIGINT, self.stop)
signal_func(SIGTERM, self.stop)
logger.info(f"Inspector started on: {sock.getsockname()}")
sock.settimeout(0.5)
try:
while self.run:
try:
conn, _ = sock.accept()
except timeout:
continue
else:
action = conn.recv(64)
if action == b"reload":
conn.send(b"\n")
self.reload()
elif action == b"shutdown":
conn.send(b"\n")
self.shutdown()
else:
data = dumps(self.state_to_json())
conn.send(data.encode())
conn.close()
finally:
logger.debug("Inspector closing")
sock.close()
def stop(self, *_):
self.run = False
def state_to_json(self):
output = {"info": self.app_info}
output["workers"] = self._make_safe(dict(self.worker_state))
return output
def reload(self):
message = "__ALL_PROCESSES__:"
self._publisher.send(message)
def shutdown(self):
message = "__TERMINATE__"
self._publisher.send(message)
def _make_safe(self, obj: Dict[str, Any]) -> Dict[str, Any]:
for key, value in obj.items():
if isinstance(value, dict):
obj[key] = self._make_safe(value)
elif isinstance(value, datetime):
obj[key] = value.isoformat()
return obj
def inspect(host: str, port: int, action: str):
out = sys.stdout.write
with socket(AF_INET, SOCK_STREAM) as sock:
try:
sock.connect((host, port))
except ConnectionRefusedError:
error_logger.error(
f"{Colors.RED}Could not connect to inspector at: "
f"{Colors.YELLOW}{(host, port)}{Colors.END}\n"
"Either the application is not running, or it did not start "
"an inspector instance."
)
sock.close()
sys.exit(1)
sock.sendall(action.encode())
data = sock.recv(4096)
if action == "raw":
out(data.decode())
elif action == "pretty":
loaded = loads(data)
display = loaded.pop("info")
extra = display.pop("extra", {})
display["packages"] = ", ".join(display["packages"])
MOTDTTY(get_logo(), f"{host}:{port}", display, extra).display(
version=False,
action="Inspecting",
out=out,
)
for name, info in loaded["workers"].items():
info = "\n".join(
f"\t{key}: {Colors.BLUE}{value}{Colors.END}"
for key, value in info.items()
)
out(
"\n"
+ indent(
"\n".join(
[
f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}",
info,
]
),
" ",
)
+ "\n"
)
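For completeness, a hedged sketch of hitting a running inspector with the module-level helper; the address is a placeholder and must match whatever host/port the Inspector was started with:
from sanic.worker.inspector import inspect
# "pretty" renders the MOTD-style summary; "raw" dumps the JSON payload,
# while "reload"/"shutdown" are acted on by the Inspector loop itself.
inspect("127.0.0.1", 6457, "pretty")  # placeholder host/port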

126
sanic/worker/loader.py Normal file
View File

@@ -0,0 +1,126 @@
from __future__ import annotations
import os
import sys
from importlib import import_module
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Type,
Union,
cast,
)
from sanic.http.tls.creators import CertCreator, MkcertCreator, TrustmeCreator
if TYPE_CHECKING:
from sanic import Sanic as SanicApp
class AppLoader:
def __init__(
self,
module_input: str = "",
as_factory: bool = False,
as_simple: bool = False,
args: Any = None,
factory: Optional[Callable[[], SanicApp]] = None,
) -> None:
self.module_input = module_input
self.module_name = ""
self.app_name = ""
self.as_factory = as_factory
self.as_simple = as_simple
self.args = args
self.factory = factory
self.cwd = os.getcwd()
if module_input:
delimiter = ":" if ":" in module_input else "."
if module_input.count(delimiter):
module_name, app_name = module_input.rsplit(delimiter, 1)
self.module_name = module_name
self.app_name = app_name
if self.app_name.endswith("()"):
self.as_factory = True
self.app_name = self.app_name[:-2]
def load(self) -> SanicApp:
module_path = os.path.abspath(self.cwd)
if module_path not in sys.path:
sys.path.append(module_path)
if self.factory:
return self.factory()
else:
from sanic.app import Sanic
from sanic.simple import create_simple_server
if self.as_simple:
path = Path(self.module_input)
app = create_simple_server(path)
else:
if self.module_name == "" and os.path.isdir(self.module_input):
raise ValueError(
"App not found.\n"
" Please use --simple if you are passing a "
"directory to sanic.\n"
f" eg. sanic {self.module_input} --simple"
)
module = import_module(self.module_name)
app = getattr(module, self.app_name, None)
if self.as_factory:
try:
app = app(self.args)
except TypeError:
app = app()
app_type_name = type(app).__name__
if (
not isinstance(app, Sanic)
and self.args
and hasattr(self.args, "module")
):
if callable(app):
solution = f"sanic {self.args.module} --factory"
raise ValueError(
"Module is not a Sanic app, it is a "
f"{app_type_name}\n"
" If this callable returns a "
f"Sanic instance try: \n{solution}"
)
raise ValueError(
f"Module is not a Sanic app, it is a {app_type_name}\n"
f" Perhaps you meant {self.args.module}:app?"
)
return app
class CertLoader:
_creator_class: Type[CertCreator]
def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]):
creator_name = ssl_data.get("creator")
if creator_name not in ("mkcert", "trustme"):
raise RuntimeError(f"Unknown certificate creator: {creator_name}")
elif creator_name == "mkcert":
self._creator_class = MkcertCreator
elif creator_name == "trustme":
self._creator_class = TrustmeCreator
self._key = ssl_data["key"]
self._cert = ssl_data["cert"]
self._localhost = cast(str, ssl_data["localhost"])
def load(self, app: SanicApp):
creator = self._creator_class(app, self._key, self._cert)
return creator.generate_cert(self._localhost)
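A brief usage sketch of the two loaders; the module path and file locations are placeholders:
from sanic.worker.loader import AppLoader, CertLoader
loader = AppLoader(module_input="server:app")  # placeholder module:attribute
app = loader.load()
cert_loader = CertLoader(
    {
        "creator": "mkcert",            # only "mkcert" and "trustme" are accepted
        "key": "/path/to/privkey.pem",  # placeholder paths
        "cert": "/path/to/fullchain.pem",
        "localhost": "localhost",
    }
)
# cert_loader.load(app) would hand these to the chosen creator and return
# the generated certificate for "localhost".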

181
sanic/worker/manager.py Normal file
View File

@@ -0,0 +1,181 @@
import os
import sys
from signal import SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from time import sleep
from typing import List, Optional
from sanic.compat import OS_IS_WINDOWS
from sanic.log import error_logger, logger
from sanic.worker.process import ProcessState, Worker, WorkerProcess
if not OS_IS_WINDOWS:
from signal import SIGKILL
else:
SIGKILL = SIGINT
class WorkerManager:
THRESHOLD = 50
def __init__(
self,
number: int,
serve,
server_settings,
context,
monitor_pubsub,
worker_state,
):
self.num_server = number
self.context = context
self.transient: List[Worker] = []
self.durable: List[Worker] = []
self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
self.worker_state = worker_state
self.worker_state["Sanic-Main"] = {"pid": self.pid}
self.terminated = False
if number == 0:
raise RuntimeError("Cannot serve with no workers")
for i in range(number):
self.manage(
f"{WorkerProcess.SERVER_LABEL}-{i}",
serve,
server_settings,
transient=True,
)
signal_func(SIGINT, self.shutdown_signal)
signal_func(SIGTERM, self.shutdown_signal)
def manage(self, ident, func, kwargs, transient=False):
container = self.transient if transient else self.durable
container.append(
Worker(ident, func, kwargs, self.context, self.worker_state)
)
def run(self):
self.start()
self.monitor()
self.join()
self.terminate()
# self.kill()
def start(self):
for process in self.processes:
process.start()
def join(self):
logger.debug("Joining processes", extra={"verbosity": 1})
joined = set()
for process in self.processes:
logger.debug(
f"Found {process.pid} - {process.state.name}",
extra={"verbosity": 1},
)
if process.state < ProcessState.JOINED:
logger.debug(f"Joining {process.pid}", extra={"verbosity": 1})
joined.add(process.pid)
process.join()
if joined:
self.join()
def terminate(self):
if not self.terminated:
for process in self.processes:
process.terminate()
self.terminated = True
def restart(self, process_names: Optional[List[str]] = None, **kwargs):
for process in self.transient_processes:
if not process_names or process.name in process_names:
process.restart(**kwargs)
def monitor(self):
self.wait_for_ack()
while True:
try:
if self.monitor_subscriber.poll(0.1):
message = self.monitor_subscriber.recv()
logger.debug(
f"Monitor message: {message}", extra={"verbosity": 2}
)
if not message:
break
elif message == "__TERMINATE__":
self.shutdown()
break
split_message = message.split(":", 1)
processes = split_message[0]
reloaded_files = (
split_message[1] if len(split_message) > 1 else None
)
process_names = [
name.strip() for name in processes.split(",")
]
if "__ALL_PROCESSES__" in process_names:
process_names = None
self.restart(
process_names=process_names,
reloaded_files=reloaded_files,
)
except InterruptedError:
if not OS_IS_WINDOWS:
raise
break
def wait_for_ack(self): # no cov
misses = 0
while not self._all_workers_ack():
sleep(0.1)
misses += 1
if misses > self.THRESHOLD:
error_logger.error("Not all workers are ack. Shutting down.")
self.kill()
sys.exit(1)
@property
def workers(self):
return self.transient + self.durable
@property
def processes(self):
for worker in self.workers:
for process in worker.processes:
yield process
@property
def transient_processes(self):
for worker in self.transient:
for process in worker.processes:
yield process
def kill(self):
for process in self.processes:
os.kill(process.pid, SIGKILL)
def shutdown_signal(self, signal, frame):
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
self.monitor_publisher.send(None)
self.shutdown()
def shutdown(self):
for process in self.processes:
if process.is_alive():
process.terminate()
@property
def pid(self):
return os.getpid()
def _all_workers_ack(self):
acked = [
worker_state.get("state") == ProcessState.ACKED.name
for worker_state in self.worker_state.values()
if worker_state.get("server")
]
return all(acked) and len(acked) == self.num_server
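A rough sketch of the wiring WorkerManager expects; in practice Sanic assembles all of this itself, and the serve target below is a stand-in:
from multiprocessing import Manager, Pipe, get_context
from sanic.worker.manager import WorkerManager
def fake_serve(**server_settings):  # stand-in for the real worker_serve target
    ...
context = get_context("spawn")
monitor_pubsub = Pipe()          # unpacked into (publisher, subscriber)
worker_state = Manager().dict()  # shared mapping keyed by worker name
manager = WorkerManager(
    number=2,
    serve=fake_serve,
    server_settings={},
    context=context,
    monitor_pubsub=monitor_pubsub,
    worker_state=worker_state,
)
# manager.run() would start the processes and block in monitor(); real workers
# must ACK through the multiplexer or wait_for_ack() eventually shuts down.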

View File

@@ -0,0 +1,48 @@
from multiprocessing.connection import Connection
from os import environ, getpid
from typing import Any, Dict
from sanic.worker.process import ProcessState
from sanic.worker.state import WorkerState
class WorkerMultiplexer:
def __init__(
self,
monitor_publisher: Connection,
worker_state: Dict[str, Any],
):
self._monitor_publisher = monitor_publisher
self._state = WorkerState(worker_state, self.name)
def ack(self):
self._state._state[self.name] = {
**self._state._state[self.name],
"state": ProcessState.ACKED.name,
}
def restart(self, name: str = ""):
if not name:
name = self.name
self._monitor_publisher.send(name)
reload = restart # no cov
def terminate(self):
self._monitor_publisher.send("__TERMINATE__")
@property
def pid(self) -> int:
return getpid()
@property
def name(self) -> str:
return environ.get("SANIC_WORKER_NAME", "")
@property
def state(self):
return self._state
@property
def workers(self) -> Dict[str, Any]:
return self.state.full()
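Inside a worker process the multiplexer is attached to the app (see sanic/worker/serve.py below), so a handler can drive restarts; the route here is purely illustrative:
from sanic import Request, Sanic, json
app = Sanic("MultiplexerDemo")  # placeholder
@app.post("/admin/restart")
async def restart_workers(request: Request):
    # With no name this restarts the current worker; .terminate() would
    # send __TERMINATE__ and bring the whole server down.
    request.app.multiplexer.restart()
    return json({"ok": True})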

161
sanic/worker/process.py Normal file
View File

@@ -0,0 +1,161 @@
import os
from datetime import datetime, timezone
from enum import IntEnum, auto
from multiprocessing.context import BaseContext
from signal import SIGINT
from typing import Any, Dict, Set
from sanic.log import Colors, logger
def get_now():
now = datetime.now(tz=timezone.utc)
return now
class ProcessState(IntEnum):
IDLE = auto()
STARTED = auto()
ACKED = auto()
JOINED = auto()
TERMINATED = auto()
class WorkerProcess:
SERVER_LABEL = "Server"
def __init__(self, factory, name, target, kwargs, worker_state):
self.state = ProcessState.IDLE
self.factory = factory
self.name = name
self.target = target
self.kwargs = kwargs
self.worker_state = worker_state
if self.name not in self.worker_state:
self.worker_state[self.name] = {
"server": self.SERVER_LABEL in self.name
}
self.spawn()
def set_state(self, state: ProcessState, force=False):
if not force and state < self.state:
raise Exception("...")
self.state = state
self.worker_state[self.name] = {
**self.worker_state[self.name],
"state": self.state.name,
}
def start(self):
os.environ["SANIC_WORKER_NAME"] = self.name
logger.debug(
f"{Colors.BLUE}Starting a process: {Colors.BOLD}"
f"{Colors.SANIC}%s{Colors.END}",
self.name,
)
self.set_state(ProcessState.STARTED)
self._process.start()
if not self.worker_state[self.name].get("starts"):
self.worker_state[self.name] = {
**self.worker_state[self.name],
"pid": self.pid,
"start_at": get_now(),
"starts": 1,
}
del os.environ["SANIC_WORKER_NAME"]
def join(self):
self.set_state(ProcessState.JOINED)
self._process.join()
def terminate(self):
if self.state is not ProcessState.TERMINATED:
logger.debug(
f"{Colors.BLUE}Terminating a process: "
f"{Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self.pid,
)
self.set_state(ProcessState.TERMINATED, force=True)
try:
# self._process.terminate()
os.kill(self.pid, SIGINT)
del self.worker_state[self.name]
except (KeyError, AttributeError, ProcessLookupError):
...
def restart(self, **kwargs):
logger.debug(
f"{Colors.BLUE}Restarting a process: {Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self.pid,
)
self._process.terminate()
self.set_state(ProcessState.IDLE, force=True)
self.kwargs.update(
{"config": {k.upper(): v for k, v in kwargs.items()}}
)
try:
self.spawn()
self.start()
except AttributeError:
raise RuntimeError("Restart failed")
self.worker_state[self.name] = {
**self.worker_state[self.name],
"pid": self.pid,
"starts": self.worker_state[self.name]["starts"] + 1,
"restart_at": get_now(),
}
def is_alive(self):
try:
return self._process.is_alive()
except AssertionError:
return False
def spawn(self):
if self.state is not ProcessState.IDLE:
raise Exception("Cannot spawn a worker process until it is idle.")
self._process = self.factory(
name=self.name,
target=self.target,
kwargs=self.kwargs,
daemon=True,
)
@property
def pid(self):
return self._process.pid
class Worker:
def __init__(
self,
ident: str,
serve,
server_settings,
context: BaseContext,
worker_state: Dict[str, Any],
):
self.ident = ident
self.context = context
self.serve = serve
self.server_settings = server_settings
self.worker_state = worker_state
self.processes: Set[WorkerProcess] = set()
self.create_process()
def create_process(self) -> WorkerProcess:
process = WorkerProcess(
factory=self.context.Process,
name=f"Sanic-{self.ident}-{len(self.processes)}",
target=self.serve,
kwargs={**self.server_settings},
worker_state=self.worker_state,
)
self.processes.add(process)
return process
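ProcessState is an IntEnum, so the lifecycle can be compared numerically, which is what set_state and WorkerManager.join rely on; a small sketch:
from sanic.worker.process import ProcessState
assert ProcessState.IDLE < ProcessState.STARTED < ProcessState.ACKED
assert ProcessState.ACKED < ProcessState.JOINED < ProcessState.TERMINATED
# set_state() refuses to move backwards unless force=True, and the manager
# only joins processes that have not yet reached JOINED.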

119
sanic/worker/reloader.py Normal file
View File

@@ -0,0 +1,119 @@
from __future__ import annotations
import os
import sys
from asyncio import new_event_loop
from itertools import chain
from multiprocessing.connection import Connection
from pathlib import Path
from signal import SIGINT, SIGTERM
from signal import signal as signal_func
from typing import Dict, Set
from sanic.server.events import trigger_events
from sanic.worker.loader import AppLoader
class Reloader:
def __init__(
self,
publisher: Connection,
interval: float,
reload_dirs: Set[Path],
app_loader: AppLoader,
):
self._publisher = publisher
self.interval = interval
self.reload_dirs = reload_dirs
self.run = True
self.app_loader = app_loader
def __call__(self) -> None:
app = self.app_loader.load()
signal_func(SIGINT, self.stop)
signal_func(SIGTERM, self.stop)
mtimes: Dict[str, float] = {}
reloader_start = app.listeners.get("reload_process_start")
reloader_stop = app.listeners.get("reload_process_stop")
before_trigger = app.listeners.get("before_reload_trigger")
after_trigger = app.listeners.get("after_reload_trigger")
loop = new_event_loop()
if reloader_start:
trigger_events(reloader_start, loop, app)
while self.run:
changed = set()
for filename in self.files():
try:
if self.check_file(filename, mtimes):
path = (
filename
if isinstance(filename, str)
else filename.resolve()
)
changed.add(str(path))
except OSError:
continue
if changed:
if before_trigger:
trigger_events(before_trigger, loop, app)
self.reload(",".join(changed) if changed else "unknown")
if after_trigger:
trigger_events(after_trigger, loop, app)
else:
if reloader_stop:
trigger_events(reloader_stop, loop, app)
def stop(self, *_):
self.run = False
def reload(self, reloaded_files):
message = f"__ALL_PROCESSES__:{reloaded_files}"
self._publisher.send(message)
def files(self):
return chain(
self.python_files(),
*(d.glob("**/*") for d in self.reload_dirs),
)
def python_files(self): # no cov
"""This iterates over all relevant Python files.
It goes through all
loaded files from modules, all files in folders of already loaded
modules as well as all files reachable through a package.
"""
# The list call is necessary on Python 3 in case the module
# dictionary modifies during iteration.
for module in list(sys.modules.values()):
if module is None:
continue
filename = getattr(module, "__file__", None)
if filename:
old = None
while not os.path.isfile(filename):
old = filename
filename = os.path.dirname(filename)
if filename == old:
break
else:
if filename[-4:] in (".pyc", ".pyo"):
filename = filename[:-1]
yield filename
@staticmethod
def check_file(filename, mtimes) -> bool:
need_reload = False
mtime = os.stat(filename).st_mtime
old_time = mtimes.get(filename)
if old_time is None:
mtimes[filename] = mtime
elif mtime > old_time:
mtimes[filename] = mtime
need_reload = True
return need_reload
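check_file is a pure mtime comparison against the caller-owned cache, which the loop above uses to build the changed set; a standalone sketch with a temporary file:
import os
import tempfile
import time
from sanic.worker.reloader import Reloader
mtimes = {}
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("x = 1\n")
path = f.name
Reloader.check_file(path, mtimes)   # False: the first sighting only records the mtime
future = time.time() + 10
os.utime(path, (future, future))    # pretend the file was just edited
Reloader.check_file(path, mtimes)   # True: the stored mtime is now older
os.unlink(path)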

124
sanic/worker/serve.py Normal file
View File

@@ -0,0 +1,124 @@
import asyncio
import os
import socket
from functools import partial
from multiprocessing.connection import Connection
from ssl import SSLContext
from typing import Any, Dict, List, Optional, Type, Union
from sanic.application.constants import ServerStage
from sanic.application.state import ApplicationServerInfo
from sanic.http.constants import HTTP
from sanic.models.server_types import Signal
from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.runners import _serve_http_1, _serve_http_3
from sanic.worker.loader import AppLoader, CertLoader
from sanic.worker.multiplexer import WorkerMultiplexer
def worker_serve(
host,
port,
app_name: str,
monitor_publisher: Optional[Connection],
app_loader: AppLoader,
worker_state: Optional[Dict[str, Any]] = None,
server_info: Optional[Dict[str, List[ApplicationServerInfo]]] = None,
ssl: Optional[
Union[SSLContext, Dict[str, Union[str, os.PathLike]]]
] = None,
sock: Optional[socket.socket] = None,
unix: Optional[str] = None,
reuse_port: bool = False,
loop=None,
protocol: Type[asyncio.Protocol] = HttpProtocol,
backlog: int = 100,
register_sys_signals: bool = True,
run_multiple: bool = False,
run_async: bool = False,
connections=None,
signal=Signal(),
state=None,
asyncio_server_kwargs=None,
version=HTTP.VERSION_1,
config=None,
passthru: Optional[Dict[str, Any]] = None,
):
from sanic import Sanic
if app_loader:
app = app_loader.load()
else:
app = Sanic.get_app(app_name)
app.refresh(passthru)
app.setup_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Hydrate server info if needed
if server_info:
for app_name, server_info_objects in server_info.items():
a = Sanic.get_app(app_name)
if not a.state.server_info:
a.state.server_info = []
for info in server_info_objects:
if not info.settings.get("app"):
info.settings["app"] = a
a.state.server_info.append(info)
if isinstance(ssl, dict):
cert_loader = CertLoader(ssl)
ssl = cert_loader.load(app)
for info in app.state.server_info:
info.settings["ssl"] = ssl
# When in a worker process, do some init
if os.environ.get("SANIC_WORKER_NAME"):
# Hydrate apps with any passed server info
if monitor_publisher is None:
raise RuntimeError("No restart publisher found in worker process")
if worker_state is None:
raise RuntimeError("No worker state found in worker process")
# Run secondary servers
apps = list(Sanic._app_registry.values())
app.before_server_start(partial(app._start_servers, apps=apps))
for a in apps:
a.multiplexer = WorkerMultiplexer(monitor_publisher, worker_state)
if app.debug:
loop.set_debug(app.debug)
app.asgi = False
if app.state.server_info:
primary_server_info = app.state.server_info[0]
primary_server_info.stage = ServerStage.SERVING
if config:
app.update_config(config)
if version is HTTP.VERSION_3:
return _serve_http_3(host, port, app, loop, ssl)
return _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
)
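worker_serve is the per-process entrypoint that manager-spawned workers ultimately call; a hedged sketch of composing it by hand with functools.partial (placeholder address and module path, and actually calling it would block in the serving loop):
from functools import partial
from sanic.worker.loader import AppLoader
from sanic.worker.serve import worker_serve
serve = partial(
    worker_serve,
    host="127.0.0.1",
    port=8000,
    app_name="MyApp",  # ignored here because app_loader is provided
    monitor_publisher=None,
    app_loader=AppLoader(module_input="server:app"),  # placeholder
)
# serve() would load the app, hydrate its server info, and enter _serve_http_1.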

85
sanic/worker/state.py Normal file
View File

@@ -0,0 +1,85 @@
from collections.abc import Mapping
from typing import Any, Dict, ItemsView, Iterator, KeysView, List
from typing import Mapping as MappingType
from typing import ValuesView
class WorkerState(Mapping):
RESTRICTED = (
"health",
"pid",
"requests",
"restart_at",
"server",
"start_at",
"starts",
"state",
)
def __init__(self, state: Dict[str, Any], current: str) -> None:
self._name = current
self._state = state
def __getitem__(self, key: str) -> Any:
return self._state[self._name][key]
def __setitem__(self, key: str, value: Any) -> None:
if key in self.RESTRICTED:
self._write_error([key])
self._state[self._name] = {
**self._state[self._name],
key: value,
}
def __delitem__(self, key: str) -> None:
if key in self.RESTRICTED:
self._write_error([key])
self._state[self._name] = {
k: v for k, v in self._state[self._name].items() if k != key
}
def __iter__(self) -> Iterator[Any]:
return iter(self._state[self._name])
def __len__(self) -> int:
return len(self._state[self._name])
def __repr__(self) -> str:
return repr(self._state[self._name])
def __eq__(self, other: object) -> bool:
return self._state[self._name] == other
def keys(self) -> KeysView[str]:
return self._state[self._name].keys()
def values(self) -> ValuesView[Any]:
return self._state[self._name].values()
def items(self) -> ItemsView[str, Any]:
return self._state[self._name].items()
def update(self, mapping: MappingType[str, Any]) -> None:
if any(k in self.RESTRICTED for k in mapping.keys()):
self._write_error(
[k for k in mapping.keys() if k in self.RESTRICTED]
)
self._state[self._name] = {
**self._state[self._name],
**mapping,
}
def pop(self) -> None:
raise NotImplementedError
def full(self) -> Dict[str, Any]:
return dict(self._state)
def _write_error(self, keys: List[str]) -> None:
raise LookupError(
f"Cannot set restricted key{'s' if len(keys) > 1 else ''} on "
f"WorkerState: {', '.join(keys)}"
)
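WorkerState is a Mapping view over one worker's slice of the shared dict, with the bookkeeping keys in RESTRICTED kept read-only; a small sketch with a plain dict standing in for the managed state:
from sanic.worker.state import WorkerState
shared = {"Sanic-Server-0-0": {"state": "ACKED"}}  # stand-in for the shared dict
state = WorkerState(shared, "Sanic-Server-0-0")
state["custom"] = "anything"    # allowed: not a restricted key
state.update({"foo": "bar"})    # allowed
# state["pid"] = 123            # would raise LookupError: "pid" is RESTRICTED
print(state.full())             # the entire mapping, all workers included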

View File

@@ -61,7 +61,7 @@ setup_kwargs = {
"Build fast. Run fast."
),
"long_description": long_description,
"packages": find_packages(),
"packages": find_packages(exclude=("tests", "tests.*")),
"package_data": {"sanic": ["py.typed"]},
"platforms": "any",
"python_requires": ">=3.7",
@@ -84,7 +84,7 @@ ujson = "ujson>=1.35" + env_dependency
uvloop = "uvloop>=0.5.3" + env_dependency
types_ujson = "types-ujson" + env_dependency
requirements = [
"sanic-routing>=22.3.0,<22.6.0",
"sanic-routing>=22.8.0",
"httptools>=0.0.10",
uvloop,
ujson,
@@ -94,14 +94,11 @@ requirements = [
]
tests_require = [
"sanic-testing>=22.3.0",
"pytest==6.2.5",
"coverage==5.3",
"gunicorn==20.0.4",
"pytest-cov",
"sanic-testing>=22.9.0b1",
"pytest",
"coverage",
"beautifulsoup4",
"pytest-sanic",
"pytest-sugar",
"pytest-benchmark",
"chardet==3.*",
"flake8",

View File

@@ -126,7 +126,7 @@ def sanic_router(app):
except RouteExists:
pass
router.finalize()
return router, added_router
return router, tuple(added_router)
return _setup

View File

@@ -1,6 +1,8 @@
import json
from sanic import Sanic, text
from sanic.application.constants import Mode
from sanic.config import Config
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
@@ -16,7 +18,7 @@ async def handler(request):
return text(request.ip)
@app.before_server_start
@app.main_process_start
async def app_info_dump(app: Sanic, _):
app_data = {
"access_log": app.config.ACCESS_LOG,
@@ -27,6 +29,13 @@ async def app_info_dump(app: Sanic, _):
logger.info(json.dumps(app_data))
@app.main_process_stop
async def app_cleanup(app: Sanic, _):
app.state.auto_reload = False
app.state.mode = Mode.PRODUCTION
app.config = Config()
@app.after_server_start
async def shutdown(app: Sanic, _):
app.stop()
@@ -38,8 +47,8 @@ def create_app():
def create_app_with_args(args):
try:
print(f"foo={args.foo}")
logger.info(f"foo={args.foo}")
except AttributeError:
print(f"module={args.module}")
logger.info(f"module={args.module}")
return app

View File

@@ -35,6 +35,7 @@ def test_server_starts_http3(app: Sanic, version, caplog):
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
single_process=True,
)
assert ev.is_set()
@@ -69,7 +70,7 @@ def test_server_starts_http1_and_http3(app: Sanic, caplog):
},
)
with caplog.at_level(logging.INFO):
Sanic.serve()
Sanic.serve_single()
assert (
"sanic.root",

View File

@@ -4,6 +4,7 @@ import re
from collections import Counter
from inspect import isawaitable
from os import environ
from unittest.mock import Mock, patch
import pytest
@@ -15,6 +16,7 @@ from sanic.compat import OS_IS_WINDOWS
from sanic.config import Config
from sanic.exceptions import SanicException
from sanic.helpers import _default
from sanic.log import LOGGING_CONFIG_DEFAULTS
from sanic.response import text
@@ -23,7 +25,7 @@ def clear_app_registry():
Sanic._app_registry = {}
def test_app_loop_running(app):
def test_app_loop_running(app: Sanic):
@app.get("/test")
async def handler(request):
assert isinstance(app.loop, asyncio.AbstractEventLoop)
@@ -33,7 +35,7 @@ def test_app_loop_running(app):
assert response.text == "pass"
def test_create_asyncio_server(app):
def test_create_asyncio_server(app: Sanic):
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(return_asyncio_server=True)
assert isawaitable(asyncio_srv_coro)
@@ -41,7 +43,7 @@ def test_create_asyncio_server(app):
assert srv.is_serving() is True
def test_asyncio_server_no_start_serving(app):
def test_asyncio_server_no_start_serving(app: Sanic):
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
port=43123,
@@ -52,7 +54,7 @@ def test_asyncio_server_no_start_serving(app):
assert srv.is_serving() is False
def test_asyncio_server_start_serving(app):
def test_asyncio_server_start_serving(app: Sanic):
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
port=43124,
@@ -69,7 +71,7 @@ def test_asyncio_server_start_serving(app):
# Looks like we can't easily test `serve_forever()`
def test_create_server_main(app, caplog):
def test_create_server_main(app: Sanic, caplog):
app.listener("main_process_start")(lambda *_: ...)
loop = asyncio.get_event_loop()
with caplog.at_level(logging.INFO):
@@ -83,7 +85,7 @@ def test_create_server_main(app, caplog):
) in caplog.record_tuples
def test_create_server_no_startup(app):
def test_create_server_no_startup(app: Sanic):
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
port=43124,
@@ -98,7 +100,7 @@ def test_create_server_no_startup(app):
loop.run_until_complete(srv.start_serving())
def test_create_server_main_convenience(app, caplog):
def test_create_server_main_convenience(app: Sanic, caplog):
app.main_process_start(lambda *_: ...)
loop = asyncio.get_event_loop()
with caplog.at_level(logging.INFO):
@@ -112,7 +114,7 @@ def test_create_server_main_convenience(app, caplog):
) in caplog.record_tuples
def test_app_loop_not_running(app):
def test_app_loop_not_running(app: Sanic):
with pytest.raises(SanicException) as excinfo:
app.loop
@@ -122,7 +124,7 @@ def test_app_loop_not_running(app):
)
def test_app_run_raise_type_error(app):
def test_app_run_raise_type_error(app: Sanic):
with pytest.raises(TypeError) as excinfo:
app.run(loop="loop")
@@ -135,7 +137,7 @@ def test_app_run_raise_type_error(app):
)
def test_app_route_raise_value_error(app):
def test_app_route_raise_value_error(app: Sanic):
with pytest.raises(ValueError) as excinfo:
@@ -149,11 +151,13 @@ def test_app_route_raise_value_error(app):
)
def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs):
return Mock(), None, {}
def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch):
mock = Mock()
mock.handler = None
def mockreturn(*args, **kwargs):
return mock, None, {}
# Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn)
@app.get("/test")
@@ -170,7 +174,7 @@ def test_app_handle_request_handler_is_none(app, monkeypatch):
@pytest.mark.parametrize("websocket_enabled", [True, False])
@pytest.mark.parametrize("enable", [True, False])
def test_app_enable_websocket(app, websocket_enabled, enable):
def test_app_enable_websocket(app: Sanic, websocket_enabled, enable):
app.websocket_enabled = websocket_enabled
app.enable_websocket(enable=enable)
@@ -180,11 +184,11 @@ def test_app_enable_websocket(app, websocket_enabled, enable):
async def handler(request, ws):
await ws.send("test")
assert app.websocket_enabled == True
assert app.websocket_enabled is True
@patch("sanic.mixins.runner.WebSocketProtocol")
def test_app_websocket_parameters(websocket_protocol_mock, app):
@patch("sanic.mixins.startup.WebSocketProtocol")
def test_app_websocket_parameters(websocket_protocol_mock, app: Sanic):
app.config.WEBSOCKET_MAX_SIZE = 44
app.config.WEBSOCKET_PING_TIMEOUT = 48
app.config.WEBSOCKET_PING_INTERVAL = 50
@@ -194,9 +198,10 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
await ws.send("test")
try:
# This will fail because WebSocketProtocol is mocked and only the call kwargs matter
# This will fail because WebSocketProtocol is mocked and only the
# call kwargs matter
app.test_client.get("/ws")
except:
except Exception:
pass
websocket_protocol_call_args = websocket_protocol_mock.call_args
@@ -212,11 +217,10 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
)
def test_handle_request_with_nested_exception(app, monkeypatch):
def test_handle_request_with_nested_exception(app: Sanic, monkeypatch):
err_msg = "Mock Exception"
# Not sure how to raise an exception in app.error_handler.response(), use mock here
def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg)
@@ -233,11 +237,10 @@ def test_handle_request_with_nested_exception(app, monkeypatch):
assert response.text == "An error occurred while handling an error"
def test_handle_request_with_nested_exception_debug(app, monkeypatch):
def test_handle_request_with_nested_exception_debug(app: Sanic, monkeypatch):
err_msg = "Mock Exception"
# Not sure how to raise an exception in app.error_handler.response(), use mock here
def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg)
@@ -252,13 +255,14 @@ def test_handle_request_with_nested_exception_debug(app, monkeypatch):
request, response = app.test_client.get("/", debug=True)
assert response.status == 500
assert response.text.startswith(
f"Error while handling error: {err_msg}\nStack: Traceback (most recent call last):\n"
f"Error while handling error: {err_msg}\n"
"Stack: Traceback (most recent call last):\n"
)
def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
# Not sure how to raise an exception in app.error_handler.response(), use mock here
def test_handle_request_with_nested_sanic_exception(
app: Sanic, monkeypatch, caplog
):
def mock_error_handler_response(*args, **kwargs):
raise SanicException("Mock SanicException")
@@ -301,8 +305,12 @@ def test_app_has_test_mode_sync():
def test_app_registry():
assert len(Sanic._app_registry) == 0
instance = Sanic("test")
assert len(Sanic._app_registry) == 1
assert Sanic._app_registry["test"] is instance
Sanic.unregister_app(instance)
assert len(Sanic._app_registry) == 0
def test_app_registry_wrong_type():
@@ -371,7 +379,7 @@ def test_get_app_default_ambiguous():
Sanic.get_app()
def test_app_set_attribute_warning(app):
def test_app_set_attribute_warning(app: Sanic):
message = (
"Setting variables on Sanic instances is not allowed. You should "
"change your Sanic instance to use instance.ctx.foo instead."
@@ -380,7 +388,7 @@ def test_app_set_attribute_warning(app):
app.foo = 1
def test_app_set_context(app):
def test_app_set_context(app: Sanic):
app.ctx.foo = 1
retrieved = Sanic.get_app(app.name)
@@ -426,13 +434,13 @@ def test_custom_context():
@pytest.mark.parametrize("use", (False, True))
def test_uvloop_config(app, monkeypatch, use):
def test_uvloop_config(app: Sanic, monkeypatch, use):
@app.get("/test")
def handler(request):
return text("ok")
try_use_uvloop = Mock()
monkeypatch.setattr(sanic.mixins.runner, "try_use_uvloop", try_use_uvloop)
monkeypatch.setattr(sanic.mixins.startup, "try_use_uvloop", try_use_uvloop)
# Default config
app.test_client.get("/test")
@@ -458,7 +466,7 @@ def test_uvloop_cannot_never_called_with_create_server(caplog, monkeypatch):
apps[2].config.USE_UVLOOP = True
try_use_uvloop = Mock()
monkeypatch.setattr(sanic.mixins.runner, "try_use_uvloop", try_use_uvloop)
monkeypatch.setattr(sanic.mixins.startup, "try_use_uvloop", try_use_uvloop)
loop = asyncio.get_event_loop()
@@ -517,12 +525,133 @@ def test_multiple_uvloop_configs_display_warning(caplog):
assert counter[(logging.WARNING, message)] == 2
def test_cannot_run_fast_and_workers(app):
def test_cannot_run_fast_and_workers(app: Sanic):
message = "You cannot use both fast=True and workers=X"
with pytest.raises(RuntimeError, match=message):
app.run(fast=True, workers=4)
def test_no_workers(app):
def test_no_workers(app: Sanic):
with pytest.raises(RuntimeError, match="Cannot serve with no workers"):
app.run(workers=0)
@pytest.mark.parametrize(
"extra",
(
{"fast": True},
{"workers": 2},
{"auto_reload": True},
),
)
def test_cannot_run_single_process_and_workers_or_auto_reload(
app: Sanic, extra
):
message = (
"Single process cannot be run with multiple workers or auto-reload"
)
with pytest.raises(RuntimeError, match=message):
app.run(single_process=True, **extra)
def test_cannot_run_single_process_and_legacy(app: Sanic):
message = "Cannot run single process and legacy mode"
with pytest.raises(RuntimeError, match=message):
app.run(single_process=True, legacy=True)
def test_cannot_run_without_sys_signals_with_workers(app: Sanic):
message = (
"Cannot run Sanic.serve with register_sys_signals=False. "
"Use either Sanic.serve_single or Sanic.serve_legacy."
)
with pytest.raises(RuntimeError, match=message):
app.run(register_sys_signals=False, single_process=False, legacy=False)
def test_default_configure_logging():
with patch("sanic.app.logging") as mock:
Sanic("Test")
mock.config.dictConfig.assert_called_with(LOGGING_CONFIG_DEFAULTS)
def test_custom_configure_logging():
with patch("sanic.app.logging") as mock:
Sanic("Test", log_config={"foo": "bar"})
mock.config.dictConfig.assert_called_with({"foo": "bar"})
def test_disable_configure_logging():
with patch("sanic.app.logging") as mock:
Sanic("Test", configure_logging=False)
mock.config.dictConfig.assert_not_called()
@pytest.mark.parametrize("inspector", (True, False))
def test_inspector(inspector):
app = Sanic("Test", inspector=inspector)
assert app.config.INSPECTOR is inspector
def test_build_endpoint_name():
app = Sanic("Test")
name = app._build_endpoint_name("foo", "bar")
assert name == "Test.foo.bar"
def test_manager_in_main_process_only(app: Sanic):
message = "Can only access the manager from the main process"
with pytest.raises(SanicException, match=message):
app.manager
app._manager = 1
environ["SANIC_WORKER_PROCESS"] = "ok"
with pytest.raises(SanicException, match=message):
app.manager
del environ["SANIC_WORKER_PROCESS"]
assert app.manager == 1
def test_inspector_in_main_process_only(app: Sanic):
message = "Can only access the inspector from the main process"
with pytest.raises(SanicException, match=message):
app.inspector
app._inspector = 1
environ["SANIC_WORKER_PROCESS"] = "ok"
with pytest.raises(SanicException, match=message):
app.inspector
del environ["SANIC_WORKER_PROCESS"]
assert app.inspector == 1
def test_stop_trigger_terminate(app: Sanic):
app.multiplexer = Mock()
app.stop()
app.multiplexer.terminate.assert_called_once()
app.multiplexer.reset_mock()
assert len(Sanic._app_registry) == 1
Sanic._app_registry.clear()
app.stop(terminate=True)
app.multiplexer.terminate.assert_called_once()
app.multiplexer.reset_mock()
assert len(Sanic._app_registry) == 0
Sanic._app_registry.clear()
app.stop(unregister=False)
app.multiplexer.terminate.assert_called_once()

View File

@@ -531,6 +531,8 @@ async def test_signals_triggered(app):
"http.lifecycle.handle",
"http.routing.before",
"http.routing.after",
"http.handler.before",
"http.handler.after",
"http.lifecycle.response",
# "http.lifecycle.send",
# "http.lifecycle.complete",
@@ -546,3 +548,13 @@ async def test_signals_triggered(app):
assert response.status_code == 200
assert response.text == "test_signals_triggered"
assert signals_triggered == signals_expected
@pytest.mark.asyncio
async def test_asgi_serve_location(app):
@app.get("/")
def _request(request: Request):
return text(request.app.serve_location)
_, response = await app.asgi_client.get("/")
assert response.text == "http://<ASGI>"

View File

@@ -1,13 +1,16 @@
import asyncio
from sanic import Sanic
def test_bad_request_response(app):
def test_bad_request_response(app: Sanic):
lines = []
app.get("/")(lambda x: ...)
@app.listener("after_server_start")
async def _request(sanic, loop):
nonlocal lines
connect = asyncio.open_connection("127.0.0.1", 42101)
reader, writer = await connect
writer.write(b"not http\r\n\r\n")
@@ -18,6 +21,6 @@ def test_bad_request_response(app):
lines.append(line)
app.stop()
app.run(host="127.0.0.1", port=42101, debug=False)
app.run(host="127.0.0.1", port=42101, debug=False, single_process=True)
assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n"
assert b"Bad Request" in lines[-2]

View File

@@ -17,7 +17,7 @@ from sanic.response import json, text
# ------------------------------------------------------------ #
def test_bp(app):
def test_bp(app: Sanic):
bp = Blueprint("test_text")
@bp.route("/")
@@ -30,7 +30,7 @@ def test_bp(app):
assert response.text == "Hello"
def test_bp_app_access(app):
def test_bp_app_access(app: Sanic):
bp = Blueprint("test")
with pytest.raises(
@@ -87,7 +87,7 @@ def test_versioned_routes_get(app, method):
assert response.status == 200
def test_bp_strict_slash(app):
def test_bp_strict_slash(app: Sanic):
bp = Blueprint("test_text")
@bp.get("/get", strict_slashes=True)
@@ -114,7 +114,7 @@ def test_bp_strict_slash(app):
assert response.status == 404
def test_bp_strict_slash_default_value(app):
def test_bp_strict_slash_default_value(app: Sanic):
bp = Blueprint("test_text", strict_slashes=True)
@bp.get("/get")
@@ -134,7 +134,7 @@ def test_bp_strict_slash_default_value(app):
assert response.status == 404
def test_bp_strict_slash_without_passing_default_value(app):
def test_bp_strict_slash_without_passing_default_value(app: Sanic):
bp = Blueprint("test_text")
@bp.get("/get")
@@ -154,7 +154,7 @@ def test_bp_strict_slash_without_passing_default_value(app):
assert response.text == "OK"
def test_bp_strict_slash_default_value_can_be_overwritten(app):
def test_bp_strict_slash_default_value_can_be_overwritten(app: Sanic):
bp = Blueprint("test_text", strict_slashes=True)
@bp.get("/get", strict_slashes=False)
@@ -174,7 +174,7 @@ def test_bp_strict_slash_default_value_can_be_overwritten(app):
assert response.text == "OK"
def test_bp_with_url_prefix(app):
def test_bp_with_url_prefix(app: Sanic):
bp = Blueprint("test_text", url_prefix="/test1")
@bp.route("/")
@@ -187,7 +187,7 @@ def test_bp_with_url_prefix(app):
assert response.text == "Hello"
def test_several_bp_with_url_prefix(app):
def test_several_bp_with_url_prefix(app: Sanic):
bp = Blueprint("test_text", url_prefix="/test1")
bp2 = Blueprint("test_text2", url_prefix="/test2")
@@ -208,7 +208,7 @@ def test_several_bp_with_url_prefix(app):
assert response.text == "Hello2"
def test_bp_with_host(app):
def test_bp_with_host(app: Sanic):
bp = Blueprint("test_bp_host", url_prefix="/test1", host="example.com")
@bp.route("/")
@@ -230,7 +230,7 @@ def test_bp_with_host(app):
assert response.body == b"Hello subdomain!"
def test_several_bp_with_host(app):
def test_several_bp_with_host(app: Sanic):
bp = Blueprint(
"test_text",
url_prefix="/test",
@@ -274,7 +274,7 @@ def test_several_bp_with_host(app):
assert response.text == "Hello3"
def test_bp_with_host_list(app):
def test_bp_with_host_list(app: Sanic):
bp = Blueprint(
"test_bp_host",
url_prefix="/test1",
@@ -304,7 +304,7 @@ def test_bp_with_host_list(app):
assert response.text == "Hello subdomain!"
def test_several_bp_with_host_list(app):
def test_several_bp_with_host_list(app: Sanic):
bp = Blueprint(
"test_text",
url_prefix="/test",
@@ -356,7 +356,7 @@ def test_several_bp_with_host_list(app):
assert response.text == "Hello3"
def test_bp_middleware(app):
def test_bp_middleware(app: Sanic):
blueprint = Blueprint("test_bp_middleware")
@blueprint.middleware("response")
@@ -375,7 +375,7 @@ def test_bp_middleware(app):
assert response.text == "FAIL"
def test_bp_middleware_with_route(app):
def test_bp_middleware_with_route(app: Sanic):
blueprint = Blueprint("test_bp_middleware")
@blueprint.middleware("response")
@@ -398,7 +398,7 @@ def test_bp_middleware_with_route(app):
assert response.text == "OK"
def test_bp_middleware_order(app):
def test_bp_middleware_order(app: Sanic):
blueprint = Blueprint("test_bp_middleware_order")
order = []
@@ -438,7 +438,7 @@ def test_bp_middleware_order(app):
assert order == [1, 2, 3, 4, 5, 6]
def test_bp_exception_handler(app):
def test_bp_exception_handler(app: Sanic):
blueprint = Blueprint("test_middleware")
@blueprint.route("/1")
@@ -470,7 +470,7 @@ def test_bp_exception_handler(app):
assert response.status == 200
def test_bp_exception_handler_applied(app):
def test_bp_exception_handler_applied(app: Sanic):
class Error(Exception):
pass
@@ -500,7 +500,7 @@ def test_bp_exception_handler_applied(app):
assert response.status == 500
def test_bp_exception_handler_not_applied(app):
def test_bp_exception_handler_not_applied(app: Sanic):
class Error(Exception):
pass
@@ -522,7 +522,7 @@ def test_bp_exception_handler_not_applied(app):
assert response.status == 500
def test_bp_listeners(app):
def test_bp_listeners(app: Sanic):
app.route("/")(lambda x: x)
blueprint = Blueprint("test_middleware")
@@ -559,7 +559,7 @@ def test_bp_listeners(app):
assert order == [1, 2, 3, 4, 5, 6]
def test_bp_static(app):
def test_bp_static(app: Sanic):
current_file = inspect.getfile(inspect.currentframe())
with open(current_file, "rb") as file:
current_file_contents = file.read()
@@ -597,7 +597,7 @@ def test_bp_static_content_type(app, file_name):
assert response.headers["Content-Type"] == "text/html; charset=utf-8"
def test_bp_shorthand(app):
def test_bp_shorthand(app: Sanic):
blueprint = Blueprint("test_shorhand_routes")
ev = asyncio.Event()
@@ -682,7 +682,7 @@ def test_bp_shorthand(app):
assert ev.is_set()
def test_bp_group(app):
def test_bp_group(app: Sanic):
deep_0 = Blueprint("deep_0", url_prefix="/deep")
deep_1 = Blueprint("deep_1", url_prefix="/deep1")
@@ -722,7 +722,7 @@ def test_bp_group(app):
assert response.text == "D1B_OK"
def test_bp_group_with_default_url_prefix(app):
def test_bp_group_with_default_url_prefix(app: Sanic):
from sanic.response import json
bp_resources = Blueprint("bp_resources")
@@ -873,7 +873,7 @@ def test_websocket_route(app: Sanic):
assert event.is_set()
def test_duplicate_blueprint(app):
def test_duplicate_blueprint(app: Sanic):
bp_name = "bp"
bp = Blueprint(bp_name)
bp1 = Blueprint(bp_name)
@@ -1056,7 +1056,7 @@ def test_bp_set_attribute_warning():
bp.foo = 1
def test_early_registration(app):
def test_early_registration(app: Sanic):
assert len(app.router.routes) == 0
bp = Blueprint("bp")
@@ -1082,3 +1082,29 @@ def test_early_registration(app):
for path in ("one", "two", "three"):
_, response = app.test_client.get(f"/{path}")
assert response.text == path
def test_remove_double_slashes_defined_on_bp(app: Sanic):
bp = Blueprint("bp", url_prefix="/foo/", strict_slashes=True)
@bp.get("/")
async def handler(_):
...
app.blueprint(bp)
app.router.finalize()
assert app.router.routes[0].path == "foo/"
def test_remove_double_slashes_defined_on_register(app: Sanic):
bp = Blueprint("bp")
@bp.get("/")
async def index(_):
...
app.blueprint(bp, url_prefix="/foo/", strict_slashes=True)
app.router.finalize()
assert app.router.routes[0].path == "foo/"

View File

@@ -0,0 +1,21 @@
import asyncio
from asyncio import CancelledError
import pytest
from sanic import Request, Sanic, json
def test_can_raise_in_handler(app: Sanic):
@app.get("/")
async def handler(request: Request):
raise CancelledError("STOP!!")
@app.exception(CancelledError)
async def handle_cancel(request: Request, exc: CancelledError):
return json({"message": exc.args[0]}, status=418)
_, response = app.test_client.get("/")
assert response.status == 418
assert response.json["message"] == "STOP!!"

View File

@@ -1,5 +1,6 @@
import json
import subprocess
import os
import sys
from pathlib import Path
from typing import List, Optional, Tuple
@@ -9,33 +10,30 @@ import pytest
from sanic_routing import __version__ as __routing_version__
from sanic import __version__
from sanic.__main__ import main
def capture(command: List[str]):
proc = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=Path(__file__).parent,
)
@pytest.fixture(scope="module", autouse=True)
def tty():
orig = sys.stdout.isatty
sys.stdout.isatty = lambda: False
yield
sys.stdout.isatty = orig
def capture(command: List[str], caplog):
caplog.clear()
os.chdir(Path(__file__).parent)
try:
out, err = proc.communicate(timeout=10)
except subprocess.TimeoutExpired:
proc.kill()
out, err = proc.communicate()
return out, err, proc.returncode
def starting_line(lines: List[str]):
for idx, line in enumerate(lines):
if line.strip().startswith(b"Sanic v"):
return idx
return 0
main(command)
except SystemExit:
...
return [record.message for record in caplog.records]
def read_app_info(lines: List[str]):
for line in lines:
if line.startswith(b"{") and line.endswith(b"}"):
if line.startswith("{") and line.endswith("}"): # type: ignore
return json.loads(line)
@@ -47,59 +45,57 @@ def read_app_info(lines: List[str]):
("fake.server.create_app()", None),
),
)
def test_server_run(appname: str, extra: Optional[str]):
command = ["sanic", appname]
def test_server_run(
appname: str,
extra: Optional[str],
caplog: pytest.LogCaptureFixture,
):
command = [appname]
if extra:
command.append(extra)
out, err, exitcode = capture(command)
lines = out.split(b"\n")
firstline = lines[starting_line(lines) + 1]
lines = capture(command, caplog)
assert exitcode != 1
assert firstline == b"Goin' Fast @ http://127.0.0.1:8000"
assert "Goin' Fast @ http://127.0.0.1:8000" in lines
def test_server_run_factory_with_args():
def test_server_run_factory_with_args(caplog):
command = [
"sanic",
"fake.server.create_app_with_args",
"--factory",
]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
lines = capture(command, caplog)
assert exitcode != 1, lines
assert b"module=fake.server.create_app_with_args" in lines
assert "module=fake.server.create_app_with_args" in lines
def test_server_run_factory_with_args_arbitrary():
def test_server_run_factory_with_args_arbitrary(caplog):
command = [
"sanic",
"fake.server.create_app_with_args",
"--factory",
"--foo=bar",
]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
lines = capture(command, caplog)
assert exitcode != 1, lines
assert b"foo=bar" in lines
assert "foo=bar" in lines
def test_error_with_function_as_instance_without_factory_arg():
command = ["sanic", "fake.server.create_app"]
out, err, exitcode = capture(command)
assert b"try: \nsanic fake.server.create_app --factory" in err
assert exitcode != 1
def test_error_with_path_as_instance_without_simple_arg():
command = ["sanic", "./fake/"]
out, err, exitcode = capture(command)
def test_error_with_function_as_instance_without_factory_arg(caplog):
command = ["fake.server.create_app"]
lines = capture(command, caplog)
assert (
b"Please use --simple if you are passing a directory to sanic." in err
)
assert exitcode != 1
"Failed to run app: Module is not a Sanic app, it is a function\n "
"If this callable returns a Sanic instance try: \n"
"sanic fake.server.create_app --factory"
) in lines
def test_error_with_path_as_instance_without_simple_arg(caplog):
command = ["./fake/"]
lines = capture(command, caplog)
assert (
"Failed to run app: App not found.\n Please use --simple if you "
"are passing a directory to sanic.\n eg. sanic ./fake/ --simple"
) in lines
@pytest.mark.parametrize(
@@ -120,13 +116,10 @@ def test_error_with_path_as_instance_without_simple_arg():
),
),
)
def test_tls_options(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd, "-p=9999", "--debug"]
out, err, exitcode = capture(command)
assert exitcode != 1
lines = out.split(b"\n")
firstline = lines[starting_line(lines) + 1]
assert firstline == b"Goin' Fast @ https://127.0.0.1:9999"
def test_tls_options(cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd, "--port=9999", "--debug"]
lines = capture(command, caplog)
assert "Goin' Fast @ https://127.0.0.1:9999" in lines
@pytest.mark.parametrize(
@@ -141,14 +134,15 @@ def test_tls_options(cmd: Tuple[str]):
("--tls-strict-host",),
),
)
def test_tls_wrong_options(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd, "-p=9999", "--debug"]
out, err, exitcode = capture(command)
assert exitcode == 1
assert not out
lines = err.decode().split("\n")
def test_tls_wrong_options(cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd, "-p=9999", "--debug"]
lines = capture(command, caplog)
assert "TLS certificates must be specified by either of:" in lines
assert (
"TLS certificates must be specified by either of:\n "
"--cert certdir/fullchain.pem --key certdir/privkey.pem\n "
"--tls certdir (equivalent to the above)"
) in lines
@pytest.mark.parametrize(
@@ -158,65 +152,44 @@ def test_tls_wrong_options(cmd: Tuple[str]):
("-H", "localhost", "-p", "9999"),
),
)
def test_host_port_localhost(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
expected = b"Goin' Fast @ http://localhost:9999"
def test_host_port_localhost(cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd]
lines = capture(command, caplog)
expected = "Goin' Fast @ http://localhost:9999"
assert exitcode != 1
assert expected in lines, f"Lines found: {lines}\nErr output: {err}"
assert expected in lines
@pytest.mark.parametrize(
"cmd",
"cmd,expected",
(
("--host=127.0.0.127", "--port=9999"),
("-H", "127.0.0.127", "-p", "9999"),
(
("--host=localhost", "--port=9999"),
"Goin' Fast @ http://localhost:9999",
),
(
("-H", "localhost", "-p", "9999"),
"Goin' Fast @ http://localhost:9999",
),
(
("--host=127.0.0.127", "--port=9999"),
"Goin' Fast @ http://127.0.0.127:9999",
),
(
("-H", "127.0.0.127", "-p", "9999"),
"Goin' Fast @ http://127.0.0.127:9999",
),
(("--host=::", "--port=9999"), "Goin' Fast @ http://[::]:9999"),
(("-H", "::", "-p", "9999"), "Goin' Fast @ http://[::]:9999"),
(("--host=::1", "--port=9999"), "Goin' Fast @ http://[::1]:9999"),
(("-H", "::1", "-p", "9999"), "Goin' Fast @ http://[::1]:9999"),
),
)
def test_host_port_ipv4(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
expected = b"Goin' Fast @ http://127.0.0.127:9999"
def test_host_port(cmd: Tuple[str, ...], expected: str, caplog):
command = ["fake.server.app", *cmd]
lines = capture(command, caplog)
assert exitcode != 1
assert expected in lines, f"Lines found: {lines}\nErr output: {err}"
@pytest.mark.parametrize(
"cmd",
(
("--host=::", "--port=9999"),
("-H", "::", "-p", "9999"),
),
)
def test_host_port_ipv6_any(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
expected = b"Goin' Fast @ http://[::]:9999"
assert exitcode != 1
assert expected in lines, f"Lines found: {lines}\nErr output: {err}"
@pytest.mark.parametrize(
"cmd",
(
("--host=::1", "--port=9999"),
("-H", "::1", "-p", "9999"),
),
)
def test_host_port_ipv6_loopback(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
expected = b"Goin' Fast @ http://[::1]:9999"
assert exitcode != 1
assert expected in lines, f"Lines found: {lines}\nErr output: {err}"
assert expected in lines
@pytest.mark.parametrize(
@@ -230,82 +203,74 @@ def test_host_port_ipv6_loopback(cmd: Tuple[str]):
(4, ("-w", "4")),
),
)
def test_num_workers(num: int, cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
def test_num_workers(num: int, cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd]
lines = capture(command, caplog)
if num == 1:
expected = b"mode: production, single worker"
expected = "mode: production, single worker"
else:
expected = (f"mode: production, w/ {num} workers").encode()
expected = f"mode: production, w/ {num} workers"
assert exitcode != 1
assert expected in lines, f"Expected {expected}\nLines found: {lines}"
assert expected in lines
@pytest.mark.parametrize("cmd", ("--debug",))
def test_debug(cmd: str):
command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
def test_debug(cmd: str, caplog):
command = ["fake.server.app", cmd]
lines = capture(command, caplog)
info = read_app_info(lines)
assert info["debug"] is True, f"Lines found: {lines}\nErr output: {err}"
assert (
info["auto_reload"] is False
), f"Lines found: {lines}\nErr output: {err}"
assert "dev" not in info, f"Lines found: {lines}\nErr output: {err}"
assert info["debug"] is True
assert info["auto_reload"] is False
@pytest.mark.parametrize("cmd", ("--dev", "-d"))
def test_dev(cmd: str):
command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
def test_dev(cmd: str, caplog):
command = ["fake.server.app", cmd]
lines = capture(command, caplog)
info = read_app_info(lines)
assert info["debug"] is True, f"Lines found: {lines}\nErr output: {err}"
assert (
info["auto_reload"] is True
), f"Lines found: {lines}\nErr output: {err}"
assert info["debug"] is True
assert info["auto_reload"] is True
@pytest.mark.parametrize("cmd", ("--auto-reload", "-r"))
def test_auto_reload(cmd: str):
command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
def test_auto_reload(cmd: str, caplog):
command = ["fake.server.app", cmd]
lines = capture(command, caplog)
info = read_app_info(lines)
assert info["debug"] is False, f"Lines found: {lines}\nErr output: {err}"
assert (
info["auto_reload"] is True
), f"Lines found: {lines}\nErr output: {err}"
assert "dev" not in info, f"Lines found: {lines}\nErr output: {err}"
assert info["debug"] is False
assert info["auto_reload"] is True
@pytest.mark.parametrize(
"cmd,expected", (("--access-log", True), ("--no-access-log", False))
"cmd,expected",
(
("", False),
("--debug", True),
("--access-log", True),
("--no-access-log", False),
),
)
def test_access_logs(cmd: str, expected: bool):
command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
def test_access_logs(cmd: str, expected: bool, caplog):
command = ["fake.server.app"]
if cmd:
command.append(cmd)
lines = capture(command, caplog)
info = read_app_info(lines)
assert (
info["access_log"] is expected
), f"Lines found: {lines}\nErr output: {err}"
assert info["access_log"] is expected
@pytest.mark.parametrize("cmd", ("--version", "-v"))
def test_version(cmd: str):
command = ["sanic", cmd]
out, err, exitcode = capture(command)
def test_version(cmd: str, caplog, capsys):
command = [cmd]
capture(command, caplog)
version_string = f"Sanic {__version__}; Routing {__routing_version__}\n"
assert out == version_string.encode("utf-8")
out, _ = capsys.readouterr()
assert version_string == out
@pytest.mark.parametrize(
@@ -315,12 +280,9 @@ def test_version(cmd: str):
("--no-noisy-exceptions", False),
),
)
def test_noisy_exceptions(cmd: str, expected: bool):
command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command)
lines = out.split(b"\n")
def test_noisy_exceptions(cmd: str, expected: bool, caplog):
command = ["fake.server.app", cmd]
lines = capture(command, caplog)
info = read_app_info(lines)
assert (
info["noisy_exceptions"] is expected
), f"Lines found: {lines}\nErr output: {err}"
assert info["noisy_exceptions"] is expected

View File

@@ -293,26 +293,21 @@ def test_config_custom_defaults_with_env():
del environ[key]
def test_config_access_log_passing_in_run(app: Sanic):
assert app.config.ACCESS_LOG is True
@pytest.mark.parametrize("access_log", (True, False))
def test_config_access_log_passing_in_run(app: Sanic, access_log):
assert app.config.ACCESS_LOG is False
@app.listener("after_server_start")
async def _request(sanic, loop):
app.stop()
app.run(port=1340, access_log=False)
assert app.config.ACCESS_LOG is False
app.router.reset()
app.signal_router.reset()
app.run(port=1340, access_log=True)
assert app.config.ACCESS_LOG is True
app.run(port=1340, access_log=access_log, single_process=True)
assert app.config.ACCESS_LOG is access_log
@pytest.mark.asyncio
async def test_config_access_log_passing_in_create_server(app: Sanic):
assert app.config.ACCESS_LOG is True
assert app.config.ACCESS_LOG is False
@app.listener("after_server_start")
async def _request(sanic, loop):

View File

@@ -1,15 +1,24 @@
import pytest
from sanic import Sanic, text
from sanic.application.constants import Mode, Server, ServerStage
from sanic.constants import HTTP_METHODS, HTTPMethod
def test_string_compat():
assert "GET" == HTTPMethod.GET
assert "GET" in HTTP_METHODS
assert "get" == HTTPMethod.GET
assert "get" in HTTP_METHODS
@pytest.mark.parametrize("enum", (HTTPMethod, Server, Mode))
def test_string_compat(enum):
for key in enum.__members__.keys():
assert key.upper() == getattr(enum, key).upper()
assert key.lower() == getattr(enum, key).lower()
assert HTTPMethod.GET.lower() == "get"
assert HTTPMethod.GET.upper() == "GET"
def test_http_methods():
for value in HTTPMethod.__members__.values():
assert value in HTTP_METHODS
def test_server_stage():
assert ServerStage.SERVING > ServerStage.PARTIAL > ServerStage.STOPPED
def test_use_in_routes(app: Sanic):

View File

@@ -1,5 +1,6 @@
import pytest
from sanic import Sanic
from sanic.log import deprecation
@@ -7,3 +8,13 @@ def test_deprecation():
message = r"\[DEPRECATION v9\.9\] hello"
with pytest.warns(DeprecationWarning, match=message):
deprecation("hello", 9.9)
@pytest.mark.parametrize(
"filter,expected",
(("default", 1), ("once", 1), ("ignore", 0)),
)
def test_deprecation_filter(app: Sanic, filter, expected, recwarn):
app.config.DEPRECATION_FILTER = filter
deprecation("hello", 9.9)
assert len(recwarn) == expected
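The test above sets the new DEPRECATION_FILTER config key and then emits a deprecation notice, so warning behavior can be steered per application. A minimal sketch of the same idea outside the test harness, assuming the key accepts the standard warnings-filter actions used in the parametrization ("default", "once", "ignore"):

# Illustrative sketch only, mirroring test_deprecation_filter above.
from sanic import Sanic
from sanic.log import deprecation

app = Sanic("DeprecationFilterDemo")

# "ignore" silences Sanic deprecation notices; "default" and "once"
# let them through, matching the expected warning counts in the test.
app.config.DEPRECATION_FILTER = "ignore"

# A notice tagged for a future version; with the filter above it should
# not be surfaced to the user.
deprecation("this API is going away", 9.9)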

View File

@@ -1,44 +1,44 @@
# import pytest
import pytest
# from sanic.response import text
# from sanic.router import RouteExists
from sanic_routing.exceptions import RouteExists
from sanic.response import text
# @pytest.mark.parametrize(
# "method,attr, expected",
# [
# ("get", "text", "OK1 test"),
# ("post", "text", "OK2 test"),
# ("put", "text", "OK2 test"),
# ("delete", "status", 405),
# ],
# )
# def test_overload_dynamic_routes(app, method, attr, expected):
# @app.route("/overload/<param>", methods=["GET"])
# async def handler1(request, param):
# return text("OK1 " + param)
@pytest.mark.parametrize(
"method,attr, expected",
[
("get", "text", "OK1 test"),
("post", "text", "OK2 test"),
("put", "text", "OK2 test"),
],
)
def test_overload_dynamic_routes(app, method, attr, expected):
@app.route("/overload/<param>", methods=["GET"])
async def handler1(request, param):
return text("OK1 " + param)
# @app.route("/overload/<param>", methods=["POST", "PUT"])
# async def handler2(request, param):
# return text("OK2 " + param)
@app.route("/overload/<param>", methods=["POST", "PUT"])
async def handler2(request, param):
return text("OK2 " + param)
# request, response = getattr(app.test_client, method)("/overload/test")
# assert getattr(response, attr) == expected
request, response = getattr(app.test_client, method)("/overload/test")
assert getattr(response, attr) == expected
# def test_overload_dynamic_routes_exist(app):
# @app.route("/overload/<param>", methods=["GET"])
# async def handler1(request, param):
# return text("OK1 " + param)
def test_overload_dynamic_routes_exist(app):
@app.route("/overload/<param>", methods=["GET"])
async def handler1(request, param):
return text("OK1 " + param)
# @app.route("/overload/<param>", methods=["POST", "PUT"])
# async def handler2(request, param):
# return text("OK2 " + param)
@app.route("/overload/<param>", methods=["POST", "PUT"])
async def handler2(request, param):
return text("OK2 " + param)
# # if this doesn't raise an error, than at least the below should happen:
# # assert response.text == 'Duplicated'
# with pytest.raises(RouteExists):
# if this doesn't raise an error, then at least the below should happen:

# assert response.text == 'Duplicated'
with pytest.raises(RouteExists):
# @app.route("/overload/<param>", methods=["PUT", "DELETE"])
# async def handler3(request, param):
# return text("Duplicated")
@app.route("/overload/<param>", methods=["PUT", "DELETE"])
async def handler3(request, param):
return text("Duplicated")

View File

@@ -353,7 +353,7 @@ def test_config_fallback_before_and_after_startup(app):
_, response = app.test_client.get("/error")
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
assert response.content_type == "application/json"
def test_config_fallback_using_update_dict(app):

View File

@@ -7,7 +7,7 @@ from unittest.mock import Mock
import pytest
from bs4 import BeautifulSoup
from pytest import LogCaptureFixture, MonkeyPatch
from pytest import LogCaptureFixture, MonkeyPatch, WarningsRecorder
from sanic import Sanic, handlers
from sanic.exceptions import BadRequest, Forbidden, NotFound, ServerError
@@ -266,3 +266,22 @@ def test_exception_handler_response_was_sent(
_, response = app.test_client.get("/2")
assert "Error" in response.text
def test_warn_on_duplicate(
app: Sanic, caplog: LogCaptureFixture, recwarn: WarningsRecorder
):
@app.exception(ServerError)
async def exception_handler_1(request, exception):
...
@app.exception(ServerError)
async def exception_handler_2(request, exception):
...
assert len(caplog.records) == 1
assert len(recwarn) == 1
assert caplog.records[0].message == (
"Duplicate exception handler definition on: route=__ALL_ROUTES__ and "
"exception=<class 'sanic.exceptions.ServerError'>"
)

View File

@@ -25,19 +25,19 @@ def stoppable_app(app):
def test_ext_is_loaded(stoppable_app: Sanic, sanic_ext):
stoppable_app.run()
stoppable_app.run(single_process=True)
sanic_ext.Extend.assert_called_once_with(stoppable_app)
def test_ext_is_not_loaded(stoppable_app: Sanic, sanic_ext):
stoppable_app.config.AUTO_EXTEND = False
stoppable_app.run()
stoppable_app.run(single_process=True)
sanic_ext.Extend.assert_not_called()
def test_extend_with_args(stoppable_app: Sanic, sanic_ext):
stoppable_app.extend(built_in_extensions=False)
stoppable_app.run()
stoppable_app.run(single_process=True)
sanic_ext.Extend.assert_called_once_with(
stoppable_app, built_in_extensions=False, config=None, extensions=None
)
@@ -80,5 +80,5 @@ def test_can_access_app_ext_while_running(app: Sanic, sanic_ext, ext_instance):
app.ext.injection(IceCream)
app.stop()
app.run()
app.run(single_process=True)
ext_instance.injection.assert_called_with(IceCream)
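These tests now call app.run(single_process=True), so the extension hooks are exercised without spawning the multi-process runner. A hedged sketch of the same invocation in an ordinary app (names and port are illustrative):

# Sketch: serving in single-process mode, as the updated tests above do.
from sanic import Sanic
from sanic.response import text

app = Sanic("SingleProcessDemo")

@app.get("/")
async def handler(request):
    return text("ok")

if __name__ == "__main__":
    # single_process=True keeps everything in the current process instead of
    # going through the multi-process worker machinery the other tests cover.
    app.run(port=8000, single_process=True)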

View File

@@ -1,48 +1,61 @@
import asyncio
import logging
import time
from multiprocessing import Process
from pytest import LogCaptureFixture
import httpx
from sanic.response import empty
PORT = 42101
def test_no_exceptions_when_cancel_pending_request(app, caplog):
def test_no_exceptions_when_cancel_pending_request(
app, caplog: LogCaptureFixture
):
app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1
@app.get("/")
async def handler(request):
await asyncio.sleep(5)
@app.after_server_start
def shutdown(app, _):
time.sleep(0.2)
@app.listener("after_server_start")
async def _request(sanic, loop):
connect = asyncio.open_connection("127.0.0.1", 8000)
_, writer = await connect
writer.write(b"GET / HTTP/1.1\r\n\r\n")
app.stop()
def ping():
time.sleep(0.1)
response = httpx.get("http://127.0.0.1:8000")
print(response.status_code)
with caplog.at_level(logging.INFO):
app.run(single_process=True, access_log=True)
p = Process(target=ping)
p.start()
assert "Request: GET http:/// stopped. Transport is closed." in caplog.text
def test_completes_request(app, caplog: LogCaptureFixture):
app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1
@app.get("/")
async def handler(request):
await asyncio.sleep(0.5)
return empty()
@app.listener("after_server_start")
async def _request(sanic, loop):
connect = asyncio.open_connection("127.0.0.1", 8000)
_, writer = await connect
writer.write(b"GET / HTTP/1.1\r\n\r\n")
app.stop()
with caplog.at_level(logging.INFO):
app.run()
app.run(single_process=True, access_log=True)
p.kill()
assert ("sanic.access", 20, "") in caplog.record_tuples
info = 0
for record in caplog.record_tuples:
assert record[1] != logging.ERROR
if record[1] == logging.INFO:
info += 1
if record[2].startswith("Request:"):
assert record[2] == (
"Request: GET http://127.0.0.1:8000/ stopped. "
"Transport is closed."
)
assert info == 11
# Make sure that the server starts shutdown process before access log
index_stopping = 0
for idx, record in enumerate(caplog.records):
if record.message.startswith("Stopping worker"):
index_stopping = idx
break
index_request = caplog.record_tuples.index(("sanic.access", 20, ""))
assert index_request > index_stopping > 0

tests/test_handler.py (new file, 36 lines)
View File

@@ -0,0 +1,36 @@
from sanic.app import Sanic
from sanic.response import empty
from sanic.signals import Event
def test_handler_operation_order(app: Sanic):
operations = []
@app.on_request
async def on_request(_):
nonlocal operations
operations.append(1)
@app.on_response
async def on_response(*_):
nonlocal operations
operations.append(5)
@app.get("/")
async def handler(_):
nonlocal operations
operations.append(3)
return empty()
@app.signal(Event.HTTP_HANDLER_BEFORE)
async def handler_before(**_):
nonlocal operations
operations.append(2)
@app.signal(Event.HTTP_HANDLER_AFTER)
async def handler_after(**_):
nonlocal operations
operations.append(4)
app.test_client.get("/")
assert operations == [1, 2, 3, 4, 5]
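test_handler_operation_order pins the execution order around a handler: request middleware, then the HTTP_HANDLER_BEFORE signal, the handler itself, HTTP_HANDLER_AFTER, and finally response middleware. A small sketch of subscribing to those two signals in an application (the print statements are illustrative):

# Sketch built on the signals exercised in tests/test_handler.py above.
from sanic import Sanic
from sanic.response import empty
from sanic.signals import Event

app = Sanic("HandlerSignalDemo")

@app.signal(Event.HTTP_HANDLER_BEFORE)
async def before_handler(**context):
    # Runs after request middleware, immediately before the route handler,
    # per the 1-2-3-4-5 ordering asserted above.
    print("about to run the handler")

@app.signal(Event.HTTP_HANDLER_AFTER)
async def after_handler(**context):
    # Runs once the handler returns, before response middleware.
    print("handler finished")

@app.get("/")
async def handler(request):
    return empty()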

View File

@@ -2,6 +2,7 @@ import json as stdjson
from collections import namedtuple
from pathlib import Path
from sys import version_info
import pytest
@@ -35,7 +36,7 @@ def test_app(app: Sanic):
@pytest.fixture
def runner(test_app):
def runner(test_app: Sanic):
client = ReusableClient(test_app, port=PORT)
client.run()
yield client
@@ -43,7 +44,7 @@ def runner(test_app):
@pytest.fixture
def client(runner):
def client(runner: ReusableClient):
client = namedtuple("Client", ("raw", "send", "recv"))
raw = RawClient(runner.host, runner.port)
@@ -74,7 +75,10 @@ def test_full_message(client):
"""
)
response = client.recv()
assert len(response) == 151
# AltSvcCheck touchup removes the Alt-Svc header from the
# response in the Python 3.9+ in this case
assert len(response) == (151 if version_info < (3, 9) else 140)
assert b"200 OK" in response

View File

@@ -61,6 +61,6 @@ def test_http1_response_has_alt_svc():
version=1,
port=PORT,
)
Sanic.serve()
Sanic.serve_single(app)
assert f'alt-svc: h3=":{PORT}"\r\n'.encode() in response

tests/test_init.py (new file, 25 lines)
View File

@@ -0,0 +1,25 @@
from importlib import import_module
import pytest
@pytest.mark.parametrize(
"item",
(
"__version__",
"Sanic",
"Blueprint",
"HTTPMethod",
"HTTPResponse",
"Request",
"Websocket",
"empty",
"file",
"html",
"json",
"redirect",
"text",
),
)
def test_imports(item):
import_module("sanic", item)

View File

@@ -3,18 +3,27 @@ import sys
from dataclasses import asdict, dataclass
from functools import partial
from json import dumps as sdumps
from string import ascii_lowercase
from typing import Dict
import pytest
try:
import ujson
from ujson import dumps as udumps
ujson_version = tuple(
map(int, ujson.__version__.strip(ascii_lowercase).split("."))
)
NO_UJSON = False
DEFAULT_DUMPS = udumps
except ModuleNotFoundError:
NO_UJSON = True
DEFAULT_DUMPS = partial(sdumps, separators=(",", ":"))
ujson_version = None
from sanic import Sanic
from sanic.response import BaseHTTPResponse, json
@@ -34,7 +43,7 @@ def foo():
@pytest.fixture
def payload(foo):
def payload(foo: Foo):
return {"foo": foo}
@@ -58,7 +67,7 @@ def test_change_encoder_to_some_custom():
@pytest.mark.skipif(NO_UJSON is True, reason="ujson not installed")
def test_json_response_ujson(payload):
def test_json_response_ujson(payload: Dict[str, Foo]):
"""ujson will look at __json__"""
response = json(payload)
assert response.body == b'{"foo":{"bar":"bar"}}'
@@ -75,7 +84,13 @@ def test_json_response_ujson(payload):
json(payload)
@pytest.mark.skipif(NO_UJSON is True, reason="ujson not installed")
@pytest.mark.skipif(
NO_UJSON is True or ujson_version >= (5, 4, 0),
reason=(
"ujson not installed or version is 5.4.0 or newer, "
"which can handle arbitrary size integers"
),
)
def test_json_response_json():
"""One of the easiest ways to tell the difference is that ujson cannot
serialize over 64 bits"""

View File

@@ -166,6 +166,7 @@ def test_access_log_client_ip_remote_addr(monkeypatch):
monkeypatch.setattr(sanic.http.http1, "access_logger", access)
app = Sanic("test_logging")
app.config.ACCESS_LOG = True
app.config.PROXIES_COUNT = 2
@app.route("/")
@@ -193,6 +194,7 @@ def test_access_log_client_ip_reqip(monkeypatch):
monkeypatch.setattr(sanic.http.http1, "access_logger", access)
app = Sanic("test_logging")
app.config.ACCESS_LOG = True
@app.route("/")
async def handler(request):

View File

@@ -30,9 +30,12 @@ def test_get_logo_returns_expected_logo(tty, full, expected):
def test_get_logo_returns_no_colors_on_apple_terminal():
platform = sys.platform
sys.platform = "darwin"
os.environ["TERM_PROGRAM"] = "Apple_Terminal"
with patch("sys.stdout.isatty") as isatty:
isatty.return_value = False
sys.platform = "darwin"
os.environ["TERM_PROGRAM"] = "Apple_Terminal"
logo = get_logo()
assert "\033" not in logo
sys.platform = platform
del os.environ["TERM_PROGRAM"]

View File

@@ -2,12 +2,22 @@ import logging
from asyncio import CancelledError
from itertools import count
from unittest.mock import Mock
import pytest
from sanic.exceptions import NotFound
from sanic.middleware import Middleware
from sanic.request import Request
from sanic.response import HTTPResponse, json, text
@pytest.fixture(autouse=True)
def reset_middleware():
yield
Middleware.reset_count()
# ------------------------------------------------------------ #
# GET
# ------------------------------------------------------------ #
@@ -166,7 +176,7 @@ def test_middleware_response_raise_cancelled_error(app, caplog):
with caplog.at_level(logging.ERROR):
request, response = app.test_client.get("/")
assert response.status == 503
assert response.status == 500
assert (
"sanic.error",
logging.ERROR,
@@ -183,7 +193,7 @@ def test_middleware_response_raise_exception(app, caplog):
with caplog.at_level(logging.ERROR):
request, response = app.test_client.get("/fail")
assert response.status == 404
assert response.status == 500
# 404 errors are not logged
assert (
"sanic.error",
@@ -318,6 +328,15 @@ def test_middleware_return_response(app):
resp1 = await request.respond()
return resp1
_, response = app.test_client.get("/")
app.test_client.get("/")
assert response_middleware_run_count == 1
assert request_middleware_run_count == 1
def test_middleware_object():
mock = Mock()
middleware = Middleware(mock)
middleware(1, 2, 3, answer=42)
mock.assert_called_once_with(1, 2, 3, answer=42)
assert middleware.order == (0, 0)

View File

@@ -0,0 +1,83 @@
from functools import partial
import pytest
from sanic import Sanic
from sanic.response import json
PRIORITY_TEST_CASES = (
([0, 1, 2], [1, 1, 1]),
([0, 1, 2], [1, 1, None]),
([0, 1, 2], [1, None, None]),
([0, 1, 2], [2, 1, None]),
([0, 1, 2], [2, 2, None]),
([0, 1, 2], [3, 2, 1]),
([0, 1, 2], [None, None, None]),
([0, 2, 1], [1, None, 1]),
([0, 2, 1], [2, None, 1]),
([0, 2, 1], [2, None, 2]),
([0, 2, 1], [3, 1, 2]),
([1, 0, 2], [1, 2, None]),
([1, 0, 2], [2, 3, 1]),
([1, 0, 2], [None, 1, None]),
([1, 2, 0], [1, 3, 2]),
([1, 2, 0], [None, 1, 1]),
([1, 2, 0], [None, 2, 1]),
([1, 2, 0], [None, 2, 2]),
([2, 0, 1], [1, None, 2]),
([2, 0, 1], [2, 1, 3]),
([2, 0, 1], [None, None, 1]),
([2, 1, 0], [1, 2, 3]),
([2, 1, 0], [None, 1, 2]),
)
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_request_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_request(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order == expected
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_response_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, response, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_response(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order[::-1] == expected
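The parametrized cases above fix the semantics of the new priority keyword on middleware registration: a higher priority value runs earlier for request middleware, while response middleware execute in the reverse order (note the order[::-1] assertion). A minimal sketch that mirrors the registration style used in the test:

# Sketch of priority-based middleware registration, as exercised above.
from functools import partial

from sanic import Sanic
from sanic.response import json

app = Sanic("MiddlewarePriorityDemo")

async def tag(request, label):
    # Request middleware; the label is purely illustrative.
    print(label)

# The higher-priority middleware should run first, regardless of
# registration order, per PRIORITY_TEST_CASES.
app.on_request(partial(tag, label="runs second (default priority)"))
app.on_request(partial(tag, label="runs first (priority=2)"), priority=2)

@app.get("/")
async def handler(request):
    return json(None)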

View File

@@ -3,15 +3,23 @@ import os
import platform
import sys
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from sanic import Sanic, __version__
from sanic import __version__
from sanic.application.logo import BASE_LOGO
from sanic.application.motd import MOTD, MOTDTTY
@pytest.fixture(autouse=True)
def reset():
try:
del os.environ["SANIC_MOTD_OUTPUT"]
except KeyError:
...
def test_logo_base(app, run_startup):
logs = run_startup(app)
@@ -63,20 +71,13 @@ def test_motd_display(caplog):
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not on 3.7")
def test_reload_dirs(app):
app.config.LOGO = None
app.config.MOTD = True
app.config.AUTO_RELOAD = True
app.prepare(reload_dir="./", auto_reload=True, motd_display={"foo": "bar"})
existing = MOTD.output
MOTD.output = Mock()
app.motd("foo")
MOTD.output.assert_called_once()
assert (
MOTD.output.call_args.args[2]["auto-reload"]
== f"enabled, {os.getcwd()}"
)
assert MOTD.output.call_args.args[3] == {"foo": "bar"}
MOTD.output = existing
Sanic._app_registry = {}
with patch.object(MOTD, "output") as mock:
app.prepare(
reload_dir="./", auto_reload=True, motd_display={"foo": "bar"}
)
mock.assert_called()
assert mock.call_args.args[2]["auto-reload"] == f"enabled, {os.getcwd()}"
assert mock.call_args.args[3] == {"foo": "bar"}

View File

@@ -1,207 +1,207 @@
import logging
# import logging
from unittest.mock import Mock
# from unittest.mock import Mock
import pytest
# import pytest
from sanic import Sanic
from sanic.response import text
from sanic.server.async_server import AsyncioServer
from sanic.signals import Event
from sanic.touchup.schemes.ode import OptionalDispatchEvent
# from sanic import Sanic
# from sanic.response import text
# from sanic.server.async_server import AsyncioServer
# from sanic.signals import Event
# from sanic.touchup.schemes.ode import OptionalDispatchEvent
try:
from unittest.mock import AsyncMock
except ImportError:
from tests.asyncmock import AsyncMock # type: ignore
# try:
# from unittest.mock import AsyncMock
# except ImportError:
# from tests.asyncmock import AsyncMock # type: ignore
@pytest.fixture
def app_one():
app = Sanic("One")
# @pytest.fixture
# def app_one():
# app = Sanic("One")
@app.get("/one")
async def one(request):
return text("one")
# @app.get("/one")
# async def one(request):
# return text("one")
return app
# return app
@pytest.fixture
def app_two():
app = Sanic("Two")
# @pytest.fixture
# def app_two():
# app = Sanic("Two")
@app.get("/two")
async def two(request):
return text("two")
# @app.get("/two")
# async def two(request):
# return text("two")
return app
# return app
@pytest.fixture(autouse=True)
def clean():
Sanic._app_registry = {}
yield
# @pytest.fixture(autouse=True)
# def clean():
# Sanic._app_registry = {}
# yield
def test_serve_same_app_multiple_tuples(app_one, run_multi):
app_one.prepare(port=23456)
app_one.prepare(port=23457)
# def test_serve_same_app_multiple_tuples(app_one, run_multi):
# app_one.prepare(port=23456)
# app_one.prepare(port=23457)
logs = run_multi(app_one)
assert (
"sanic.root",
logging.INFO,
"Goin' Fast @ http://127.0.0.1:23456",
) in logs
assert (
"sanic.root",
logging.INFO,
"Goin' Fast @ http://127.0.0.1:23457",
) in logs
# logs = run_multi(app_one)
# assert (
# "sanic.root",
# logging.INFO,
# "Goin' Fast @ http://127.0.0.1:23456",
# ) in logs
# assert (
# "sanic.root",
# logging.INFO,
# "Goin' Fast @ http://127.0.0.1:23457",
# ) in logs
def test_serve_multiple_apps(app_one, app_two, run_multi):
app_one.prepare(port=23456)
app_two.prepare(port=23457)
# def test_serve_multiple_apps(app_one, app_two, run_multi):
# app_one.prepare(port=23456)
# app_two.prepare(port=23457)
logs = run_multi(app_one)
assert (
"sanic.root",
logging.INFO,
"Goin' Fast @ http://127.0.0.1:23456",
) in logs
assert (
"sanic.root",
logging.INFO,
"Goin' Fast @ http://127.0.0.1:23457",
) in logs
# logs = run_multi(app_one)
# assert (
# "sanic.root",
# logging.INFO,
# "Goin' Fast @ http://127.0.0.1:23456",
# ) in logs
# assert (
# "sanic.root",
# logging.INFO,
# "Goin' Fast @ http://127.0.0.1:23457",
# ) in logs
def test_listeners_on_secondary_app(app_one, app_two, run_multi):
app_one.prepare(port=23456)
app_two.prepare(port=23457)
# def test_listeners_on_secondary_app(app_one, app_two, run_multi):
# app_one.prepare(port=23456)
# app_two.prepare(port=23457)
before_start = AsyncMock()
after_start = AsyncMock()
before_stop = AsyncMock()
after_stop = AsyncMock()
# before_start = AsyncMock()
# after_start = AsyncMock()
# before_stop = AsyncMock()
# after_stop = AsyncMock()
app_two.before_server_start(before_start)
app_two.after_server_start(after_start)
app_two.before_server_stop(before_stop)
app_two.after_server_stop(after_stop)
# app_two.before_server_start(before_start)
# app_two.after_server_start(after_start)
# app_two.before_server_stop(before_stop)
# app_two.after_server_stop(after_stop)
run_multi(app_one)
# run_multi(app_one)
before_start.assert_awaited_once()
after_start.assert_awaited_once()
before_stop.assert_awaited_once()
after_stop.assert_awaited_once()
# before_start.assert_awaited_once()
# after_start.assert_awaited_once()
# before_stop.assert_awaited_once()
# after_stop.assert_awaited_once()
@pytest.mark.parametrize(
"events",
(
(Event.HTTP_LIFECYCLE_BEGIN,),
(Event.HTTP_LIFECYCLE_BEGIN, Event.HTTP_LIFECYCLE_COMPLETE),
(
Event.HTTP_LIFECYCLE_BEGIN,
Event.HTTP_LIFECYCLE_COMPLETE,
Event.HTTP_LIFECYCLE_REQUEST,
),
),
)
def test_signal_synchronization(app_one, app_two, run_multi, events):
app_one.prepare(port=23456)
app_two.prepare(port=23457)
# @pytest.mark.parametrize(
# "events",
# (
# (Event.HTTP_LIFECYCLE_BEGIN,),
# (Event.HTTP_LIFECYCLE_BEGIN, Event.HTTP_LIFECYCLE_COMPLETE),
# (
# Event.HTTP_LIFECYCLE_BEGIN,
# Event.HTTP_LIFECYCLE_COMPLETE,
# Event.HTTP_LIFECYCLE_REQUEST,
# ),
# ),
# )
# def test_signal_synchronization(app_one, app_two, run_multi, events):
# app_one.prepare(port=23456)
# app_two.prepare(port=23457)
for event in events:
app_one.signal(event)(AsyncMock())
# for event in events:
# app_one.signal(event)(AsyncMock())
run_multi(app_one)
# run_multi(app_one)
assert len(app_two.signal_router.routes) == len(events) + 1
# assert len(app_two.signal_router.routes) == len(events) + 1
signal_handlers = {
signal.handler
for signal in app_two.signal_router.routes
if signal.name.startswith("http")
}
# signal_handlers = {
# signal.handler
# for signal in app_two.signal_router.routes
# if signal.name.startswith("http")
# }
assert len(signal_handlers) == 1
assert list(signal_handlers)[0] is OptionalDispatchEvent.noop
# assert len(signal_handlers) == 1
# assert list(signal_handlers)[0] is OptionalDispatchEvent.noop
def test_warning_main_process_listeners_on_secondary(
app_one, app_two, run_multi
):
app_two.main_process_start(AsyncMock())
app_two.main_process_stop(AsyncMock())
app_one.prepare(port=23456)
app_two.prepare(port=23457)
# def test_warning_main_process_listeners_on_secondary(
# app_one, app_two, run_multi
# ):
# app_two.main_process_start(AsyncMock())
# app_two.main_process_stop(AsyncMock())
# app_one.prepare(port=23456)
# app_two.prepare(port=23457)
log = run_multi(app_one)
# log = run_multi(app_one)
message = (
f"Sanic found 2 listener(s) on "
"secondary applications attached to the main "
"process. These will be ignored since main "
"process listeners can only be attached to your "
"primary application: "
f"{repr(app_one)}"
)
# message = (
# f"Sanic found 2 listener(s) on "
# "secondary applications attached to the main "
# "process. These will be ignored since main "
# "process listeners can only be attached to your "
# "primary application: "
# f"{repr(app_one)}"
# )
assert ("sanic.error", logging.WARNING, message) in log
# assert ("sanic.error", logging.WARNING, message) in log
def test_no_applications():
Sanic._app_registry = {}
message = "Did not find any applications."
with pytest.raises(RuntimeError, match=message):
Sanic.serve()
# def test_no_applications():
# Sanic._app_registry = {}
# message = "Did not find any applications."
# with pytest.raises(RuntimeError, match=message):
# Sanic.serve()
def test_oserror_warning(app_one, app_two, run_multi, capfd):
orig = AsyncioServer.__await__
AsyncioServer.__await__ = Mock(side_effect=OSError("foo"))
app_one.prepare(port=23456, workers=2)
app_two.prepare(port=23457, workers=2)
# def test_oserror_warning(app_one, app_two, run_multi, capfd):
# orig = AsyncioServer.__await__
# AsyncioServer.__await__ = Mock(side_effect=OSError("foo"))
# app_one.prepare(port=23456, workers=2)
# app_two.prepare(port=23457, workers=2)
run_multi(app_one)
# run_multi(app_one)
captured = capfd.readouterr()
assert (
"An OSError was detected on startup. The encountered error was: foo"
) in captured.err
# captured = capfd.readouterr()
# assert (
# "An OSError was detected on startup. The encountered error was: foo"
# ) in captured.err
AsyncioServer.__await__ = orig
# AsyncioServer.__await__ = orig
def test_running_multiple_offset_warning(app_one, app_two, run_multi, capfd):
app_one.prepare(port=23456, workers=2)
app_two.prepare(port=23457)
# def test_running_multiple_offset_warning(app_one, app_two, run_multi, capfd):
# app_one.prepare(port=23456, workers=2)
# app_two.prepare(port=23457)
run_multi(app_one)
# run_multi(app_one)
captured = capfd.readouterr()
assert (
f"The primary application {repr(app_one)} is running "
"with 2 worker(s). All "
"application instances will run with the same number. "
f"You requested {repr(app_two)} to run with "
"1 worker(s), which will be ignored "
"in favor of the primary application."
) in captured.err
# captured = capfd.readouterr()
# assert (
# f"The primary application {repr(app_one)} is running "
# "with 2 worker(s). All "
# "application instances will run with the same number. "
# f"You requested {repr(app_two)} to run with "
# "1 worker(s), which will be ignored "
# "in favor of the primary application."
# ) in captured.err
def test_running_multiple_secondary(app_one, app_two, run_multi, capfd):
app_one.prepare(port=23456, workers=2)
app_two.prepare(port=23457)
# def test_running_multiple_secondary(app_one, app_two, run_multi, capfd):
# app_one.prepare(port=23456, workers=2)
# app_two.prepare(port=23457)
before_start = AsyncMock()
app_two.before_server_start(before_start)
run_multi(app_one)
# before_start = AsyncMock()
# app_two.before_server_start(before_start)
# run_multi(app_one)
before_start.await_count == 2
# before_start.await_count == 2

View File

@@ -4,13 +4,15 @@ import pickle
import random
import signal
from asyncio import sleep
import pytest
from sanic_testing.testing import HOST, PORT
from sanic import Blueprint
from sanic import Blueprint, text
from sanic.log import logger
from sanic.response import text
from sanic.server.socket import configure_socket
@pytest.mark.skipif(
@@ -24,14 +26,108 @@ def test_multiprocessing(app):
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set()
@app.after_server_start
async def shutdown(app):
await sleep(2.1)
app.stop()
def stop_on_alarm(*args):
for process in multiprocessing.active_children():
process_list.add(process.pid)
process.terminate()
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(3)
app.run(HOST, PORT, workers=num_workers)
signal.alarm(2)
app.run(HOST, 4120, workers=num_workers, debug=True)
assert len(process_list) == num_workers + 1
@pytest.mark.skipif(
not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these",
)
def test_multiprocessing_legacy(app):
"""Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set()
@app.after_server_start
async def shutdown(app):
await sleep(2.1)
app.stop()
def stop_on_alarm(*args):
for process in multiprocessing.active_children():
process_list.add(process.pid)
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(2)
app.run(HOST, 4121, workers=num_workers, debug=True, legacy=True)
assert len(process_list) == num_workers
@pytest.mark.skipif(
not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these",
)
def test_multiprocessing_legacy_sock(app):
"""Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set()
@app.after_server_start
async def shutdown(app):
await sleep(2.1)
app.stop()
def stop_on_alarm(*args):
for process in multiprocessing.active_children():
process_list.add(process.pid)
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(2)
sock = configure_socket(
{
"host": HOST,
"port": 4121,
"unix": None,
"backlog": 100,
}
)
app.run(workers=num_workers, debug=True, legacy=True, sock=sock)
sock.close()
assert len(process_list) == num_workers
@pytest.mark.skipif(
not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these",
)
def test_multiprocessing_legacy_unix(app):
"""Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set()
@app.after_server_start
async def shutdown(app):
await sleep(2.1)
app.stop()
def stop_on_alarm(*args):
for process in multiprocessing.active_children():
process_list.add(process.pid)
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(2)
app.run(workers=num_workers, debug=True, legacy=True, unix="./test.sock")
assert len(process_list) == num_workers
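The new *_legacy variants keep the older multiprocessing runner under test behind the legacy=True flag, including the socket and unix-socket setups. A short sketch of opting into it (host, port, and worker count are illustrative):

# Sketch: running with the legacy multi-worker server covered by the tests above.
from sanic import Sanic
from sanic.response import text

app = Sanic("LegacyRunnerDemo")

@app.get("/")
async def handler(request):
    return text("hello")

if __name__ == "__main__":
    # legacy=True selects the older multiprocessing-based runner instead of
    # the current default, matching test_multiprocessing_legacy above.
    app.run("127.0.0.1", 8000, workers=2, legacy=True)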
@@ -45,19 +141,23 @@ def test_multiprocessing_with_blueprint(app):
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set()
@app.after_server_start
async def shutdown(app):
await sleep(2.1)
app.stop()
def stop_on_alarm(*args):
for process in multiprocessing.active_children():
process_list.add(process.pid)
process.terminate()
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(3)
signal.alarm(2)
bp = Blueprint("test_text")
app.blueprint(bp)
app.run(HOST, PORT, workers=num_workers)
app.run(HOST, 4121, workers=num_workers, debug=True)
assert len(process_list) == num_workers
assert len(process_list) == num_workers + 1
# this function must be outside a test function so that it can be
@@ -66,62 +166,58 @@ def handler(request):
return text("Hello")
def stop(app):
app.stop()
# Multiprocessing on Windows requires app to be able to be pickled
@pytest.mark.parametrize("protocol", [3, 4])
def test_pickle_app(app, protocol):
app.route("/")(handler)
app.router.finalize()
app.after_server_start(stop)
app.router.reset()
app.signal_router.reset()
p_app = pickle.dumps(app, protocol=protocol)
del app
up_p_app = pickle.loads(p_app)
up_p_app.router.finalize()
assert up_p_app
request, response = up_p_app.test_client.get("/")
assert response.text == "Hello"
up_p_app.run(single_process=True)
@pytest.mark.parametrize("protocol", [3, 4])
def test_pickle_app_with_bp(app, protocol):
bp = Blueprint("test_text")
bp.route("/")(handler)
bp.after_server_start(stop)
app.blueprint(bp)
app.router.finalize()
app.router.reset()
app.signal_router.reset()
p_app = pickle.dumps(app, protocol=protocol)
del app
up_p_app = pickle.loads(p_app)
up_p_app.router.finalize()
assert up_p_app
request, response = up_p_app.test_client.get("/")
assert response.text == "Hello"
up_p_app.run(single_process=True)
@pytest.mark.parametrize("protocol", [3, 4])
def test_pickle_app_with_static(app, protocol):
app.route("/")(handler)
app.after_server_start(stop)
app.static("/static", "/tmp/static")
app.router.finalize()
app.router.reset()
app.signal_router.reset()
p_app = pickle.dumps(app, protocol=protocol)
del app
up_p_app = pickle.loads(p_app)
up_p_app.router.finalize()
assert up_p_app
request, response = up_p_app.test_client.get("/static/missing.txt")
assert response.status == 404
up_p_app.run(single_process=True)
def test_main_process_event(app, caplog):
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
def stop_on_alarm(*args):
for process in multiprocessing.active_children():
process.terminate()
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(1)
app.after_server_start(stop)
@app.listener("main_process_start")
def main_process_start(app, loop):

View File

@@ -1,4 +1,3 @@
from httpx import AsyncByteStream
from sanic_testing.reusable import ReusableClient
from sanic.response import json, text

View File

@@ -17,6 +17,10 @@ def no_skip():
yield
Sanic._app_registry = {}
Sanic.should_auto_reload = should_auto_reload
try:
del os.environ["SANIC_MOTD_OUTPUT"]
except KeyError:
...
def get_primary(app: Sanic) -> ApplicationServerInfo:
@@ -55,17 +59,21 @@ def test_reload_dir(app: Sanic, dirs, caplog):
assert ("sanic.root", logging.WARNING, message) in caplog.record_tuples
def test_fast(app: Sanic, run_multi):
app.prepare(fast=True)
def test_fast(app: Sanic, caplog):
@app.after_server_start
async def stop(app, _):
app.stop()
try:
workers = len(os.sched_getaffinity(0))
except AttributeError:
workers = os.cpu_count() or 1
with caplog.at_level(logging.INFO):
app.prepare(fast=True)
assert app.state.fast
assert app.state.workers == workers
logs = run_multi(app, logging.INFO)
messages = [m[2] for m in logs]
messages = [m[2] for m in caplog.record_tuples]
assert f"mode: production, goin' fast w/ {workers} workers" in messages

View File

@@ -243,3 +243,54 @@ def test_request_stream_id(app):
_, resp = app.test_client.get("/")
assert resp.text == "Stream ID is only a property of a HTTP/3 request"
@pytest.mark.parametrize(
"method,safe",
(
("DELETE", False),
("GET", True),
("HEAD", True),
("OPTIONS", True),
("PATCH", False),
("POST", False),
("PUT", False),
),
)
def test_request_safe(method, safe):
request = Request(b"/", {}, None, method, None, None)
assert request.is_safe is safe
@pytest.mark.parametrize(
"method,idempotent",
(
("DELETE", True),
("GET", True),
("HEAD", True),
("OPTIONS", True),
("PATCH", False),
("POST", False),
("PUT", True),
),
)
def test_request_idempotent(method, idempotent):
request = Request(b"/", {}, None, method, None, None)
assert request.is_idempotent is idempotent
@pytest.mark.parametrize(
"method,cacheable",
(
("DELETE", False),
("GET", True),
("HEAD", True),
("OPTIONS", False),
("PATCH", False),
("POST", False),
("PUT", False),
),
)
def test_request_cacheable(method, cacheable):
request = Request(b"/", {}, None, method, None, None)
assert request.is_cacheable is cacheable
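The three tables above define the new method-info properties on Request: is_safe, is_idempotent, and is_cacheable. A tiny sketch that builds bare Request objects the same way the tests do and reads those properties:

# Sketch using the properties asserted above; constructor arguments mirror
# the bare Request instances built in the tests.
from sanic.request import Request

get_request = Request(b"/", {}, None, "GET", None, None)
assert get_request.is_safe
assert get_request.is_idempotent
assert get_request.is_cacheable

delete_request = Request(b"/", {}, None, "DELETE", None, None)
# DELETE is idempotent but neither safe nor cacheable, per the tables above.
assert delete_request.is_idempotent
assert not delete_request.is_safe
assert not delete_request.is_cacheable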

View File

@@ -568,7 +568,7 @@ def test_streaming_echo():
@app.listener("after_server_start")
async def client_task(app, loop):
try:
reader, writer = await asyncio.open_connection(*addr)
reader, writer = await asyncio.open_connection("localhost", 8000)
await client(app, reader, writer)
finally:
writer.close()
@@ -576,7 +576,7 @@ def test_streaming_echo():
async def client(app, reader, writer):
# Unfortunately httpx does not support 2-way streaming, so do it by hand.
host = f"host: {addr[0]}:{addr[1]}\r\n".encode()
host = f"host: localhost:8000\r\n".encode()
writer.write(
b"POST /echo HTTP/1.1\r\n" + host + b"content-length: 2\r\n"
b"content-type: text/plain; charset=utf-8\r\n"
@@ -625,6 +625,4 @@ def test_streaming_echo():
# Use random port for tests
with closing(socket()) as sock:
sock.bind(("127.0.0.1", 0))
addr = sock.getsockname()
app.run(sock=sock, access_log=False)
app.run(access_log=False)

View File

@@ -1,6 +1,7 @@
import asyncio
import inspect
import os
import time
from collections import namedtuple
from datetime import datetime
@@ -730,8 +731,10 @@ def test_file_response_headers(
test_expires = test_last_modified.timestamp() + test_max_age
@app.route("/files/cached/<filename>", methods=["GET"])
def file_route_cache(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
def file_route_cache(request: Request, filename: str):
file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(
file_path, max_age=test_max_age, last_modified=test_last_modified
)
@@ -739,18 +742,26 @@ def test_file_response_headers(
@app.route(
"/files/cached_default_last_modified/<filename>", methods=["GET"]
)
def file_route_cache_default_last_modified(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
def file_route_cache_default_last_modified(
request: Request, filename: str
):
file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(file_path, max_age=test_max_age)
@app.route("/files/no_cache/<filename>", methods=["GET"])
def file_route_no_cache(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
def file_route_no_cache(request: Request, filename: str):
file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(file_path)
@app.route("/files/no_store/<filename>", methods=["GET"])
def file_route_no_store(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
def file_route_no_store(request: Request, filename: str):
file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(file_path, no_store=True)
_, response = app.test_client.get(f"/files/cached/{file_name}")
@@ -767,11 +778,11 @@ def test_file_response_headers(
== formatdate(test_expires, usegmt=True)[:-6]
# [:-6] to allow at most 1 min difference
# It's minimal for cases like:
# Thu, 26 May 2022 05:36:49 GMT
# Thu, 26 May 2022 05:36:59 GMT
# AND
# Thu, 26 May 2022 05:36:50 GMT
# Thu, 26 May 2022 05:37:00 GMT
)
assert response.status == 200
assert "last-modified" in headers and headers.get(
"last-modified"
) == formatdate(test_last_modified.timestamp(), usegmt=True)
@@ -786,15 +797,127 @@ def test_file_response_headers(
assert "last-modified" in headers and headers.get(
"last-modified"
) == formatdate(file_last_modified, usegmt=True)
assert response.status == 200
_, response = app.test_client.get(f"/files/no_cache/{file_name}")
headers = response.headers
assert "cache-control" in headers and f"no-cache" == headers.get(
"cache-control"
)
assert response.status == 200
_, response = app.test_client.get(f"/files/no_store/{file_name}")
headers = response.headers
assert "cache-control" in headers and f"no-store" == headers.get(
"cache-control"
)
assert response.status == 200
def test_file_validate(app: Sanic, static_file_directory: str):
file_name = "test_validate.txt"
static_file_directory = Path(static_file_directory)
file_path = static_file_directory / file_name
file_path = file_path.absolute()
test_max_age = 10
with open(file_path, "w+") as f:
f.write("foo\n")
@app.route("/validate", methods=["GET"])
def file_route_cache(request: Request):
return file(
file_path,
request_headers=request.headers,
max_age=test_max_age,
validate_when_requested=True,
)
_, response = app.test_client.get("/validate")
assert response.status == 200
assert response.body == b"foo\n"
last_modified = response.headers["Last-Modified"]
time.sleep(1)
with open(file_path, "a") as f:
f.write("bar\n")
_, response = app.test_client.get(
"/validate", headers={"If-Modified-Since": last_modified}
)
assert response.status == 200
assert response.body == b"foo\nbar\n"
last_modified = response.headers["Last-Modified"]
_, response = app.test_client.get(
"/validate", headers={"if-modified-since": last_modified}
)
assert response.status == 304
assert response.body == b""
file_path.unlink()
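test_file_validate drives the new conditional-response behavior of file(): when the handler forwards the request headers and sets validate_when_requested=True, an unchanged file yields 304 Not Modified while a modified file is re-sent. A minimal handler sketch (the path and max_age are illustrative), mirroring the sync-handler style used in the test:

# Sketch of the validation flow covered by test_file_validate above.
from sanic import Request, Sanic
from sanic.response import file

app = Sanic("FileValidateDemo")

@app.get("/report")
def report(request: Request):
    # Forwarding the request headers lets file() answer 304 when the
    # client's If-Modified-Since still matches the file on disk.
    return file(
        "./report.txt",
        request_headers=request.headers,
        max_age=10,
        validate_when_requested=True,
    )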
@pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_validating_invalid_header(
app: Sanic, file_name: str, static_file_directory: str
):
@app.route("/files/<filename>", methods=["GET"])
def file_route(request: Request, filename: str):
handler_file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(
handler_file_path,
request_headers=request.headers,
validate_when_requested=True,
)
_, response = app.test_client.get(f"/files/{file_name}")
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
_, response = app.test_client.get(
f"/files/{file_name}", headers={"if-modified-since": "invalid-value"}
)
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
_, response = app.test_client.get(
f"/files/{file_name}", headers={"if-modified-since": ""}
)
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
@pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_validating_304_response(
app: Sanic, file_name: str, static_file_directory: str
):
@app.route("/files/<filename>", methods=["GET"])
def file_route(request: Request, filename: str):
handler_file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(
handler_file_path,
request_headers=request.headers,
validate_when_requested=True,
)
_, response = app.test_client.get(f"/files/{file_name}")
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
_, response = app.test_client.get(
f"/files/{file_name}",
headers={"if-modified-since": response.headers["Last-Modified"]},
)
assert response.status == 304
assert response.body == b""

View File

@@ -3,69 +3,80 @@ import logging
from time import sleep
import pytest
from sanic import Sanic
from sanic.exceptions import ServiceUnavailable
from sanic.log import LOGGING_CONFIG_DEFAULTS
from sanic.response import text
response_timeout_app = Sanic("test_response_timeout")
response_timeout_default_app = Sanic("test_response_timeout_default")
response_handler_cancelled_app = Sanic("test_response_handler_cancelled")
@pytest.fixture
def response_timeout_app():
app = Sanic("test_response_timeout")
app.config.RESPONSE_TIMEOUT = 1
response_timeout_app.config.RESPONSE_TIMEOUT = 1
response_timeout_default_app.config.RESPONSE_TIMEOUT = 1
response_handler_cancelled_app.config.RESPONSE_TIMEOUT = 1
@app.route("/1")
async def handler_1(request):
await asyncio.sleep(2)
return text("OK")
response_handler_cancelled_app.ctx.flag = False
@app.exception(ServiceUnavailable)
def handler_exception(request, exception):
return text("Response Timeout from error_handler.", 503)
return app
@response_timeout_app.route("/1")
async def handler_1(request):
await asyncio.sleep(2)
return text("OK")
@pytest.fixture
def response_timeout_default_app():
app = Sanic("test_response_timeout_default")
app.config.RESPONSE_TIMEOUT = 1
@app.route("/1")
async def handler_2(request):
await asyncio.sleep(2)
return text("OK")
return app
@response_timeout_app.exception(ServiceUnavailable)
def handler_exception(request, exception):
return text("Response Timeout from error_handler.", 503)
@pytest.fixture
def response_handler_cancelled_app():
app = Sanic("test_response_handler_cancelled")
app.config.RESPONSE_TIMEOUT = 1
app.ctx.flag = False
@app.exception(asyncio.CancelledError)
def handler_cancelled(request, exception):
# If we get a CancelledError, it means sanic has already sent a response,
# we should not ever have to handle a CancelledError.
response_handler_cancelled_app.ctx.flag = True
return text("App received CancelledError!", 500)
# The client will never receive this response, because the socket
# is already closed when we get a CancelledError.
@app.route("/1")
async def handler_3(request):
await asyncio.sleep(2)
return text("OK")
return app
@response_timeout_default_app.route("/1")
async def handler_2(request):
await asyncio.sleep(2)
return text("OK")
@response_handler_cancelled_app.exception(asyncio.CancelledError)
def handler_cancelled(request, exception):
# If we get a CancelledError, it means sanic has already sent a response,
# we should not ever have to handle a CancelledError.
response_handler_cancelled_app.ctx.flag = True
return text("App received CancelledError!", 500)
# The client will never receive this response, because the socket
# is already closed when we get a CancelledError.
@response_handler_cancelled_app.route("/1")
async def handler_3(request):
await asyncio.sleep(2)
return text("OK")
def test_server_error_response_timeout():
def test_server_error_response_timeout(response_timeout_app):
request, response = response_timeout_app.test_client.get("/1")
assert response.status == 503
assert response.text == "Response Timeout from error_handler."
def test_default_server_error_response_timeout():
def test_default_server_error_response_timeout(response_timeout_default_app):
request, response = response_timeout_default_app.test_client.get("/1")
assert response.status == 503
assert "Response Timeout" in response.text
def test_response_handler_cancelled():
def test_response_handler_cancelled(response_handler_cancelled_app):
request, response = response_handler_cancelled_app.test_client.get("/1")
assert response.status == 503
assert "Response Timeout" in response.text

View File

@@ -1266,3 +1266,22 @@ async def test_added_callable_route_ctx_kwargs(app):
assert request.route.ctx.foo() == "foo"
assert await request.route.ctx.bar() == 99
@pytest.mark.asyncio
async def test_duplicate_route_deprecation(app):
@app.route("/foo", name="duped")
async def handler_foo(request):
return text("...")
@app.route("/bar", name="duped")
async def handler_bar(request):
return text("...")
message = (
r"\[DEPRECATION v23\.3\] Duplicate route names detected: "
r"test_duplicate_route_deprecation\.duped\. In the future, "
r"Sanic will enforce uniqueness in route naming\."
)
with pytest.warns(DeprecationWarning, match=message):
await app._startup()

View File

@@ -18,12 +18,6 @@ AVAILABLE_LISTENERS = [
"after_server_stop",
]
skipif_no_alarm = pytest.mark.skipif(
not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these",
)
def create_listener(listener_name, in_list):
async def _listener(app, loop):
@@ -42,18 +36,17 @@ def create_listener_no_loop(listener_name, in_list):
def start_stop_app(random_name_app, **run_kwargs):
def stop_on_alarm(signum, frame):
random_name_app.stop()
@random_name_app.after_server_start
async def shutdown(app):
await asyncio.sleep(1.1)
app.stop()
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(1)
try:
random_name_app.run(HOST, PORT, **run_kwargs)
random_name_app.run(HOST, PORT, single_process=True, **run_kwargs)
except KeyboardInterrupt:
pass
@skipif_no_alarm
@pytest.mark.parametrize("listener_name", AVAILABLE_LISTENERS)
def test_single_listener(app, listener_name):
"""Test that listeners on their own work"""
@@ -64,7 +57,6 @@ def test_single_listener(app, listener_name):
assert app.name + listener_name == output.pop()
@skipif_no_alarm
@pytest.mark.parametrize("listener_name", AVAILABLE_LISTENERS)
def test_single_listener_no_loop(app, listener_name):
"""Test that listeners on their own work"""
@@ -75,7 +67,6 @@ def test_single_listener_no_loop(app, listener_name):
assert app.name + listener_name == output.pop()
@skipif_no_alarm
@pytest.mark.parametrize("listener_name", AVAILABLE_LISTENERS)
def test_register_listener(app, listener_name):
"""
@@ -90,7 +81,6 @@ def test_register_listener(app, listener_name):
assert app.name + listener_name == output.pop()
@skipif_no_alarm
def test_all_listeners(app):
output = []
for listener_name in AVAILABLE_LISTENERS:
@@ -101,7 +91,6 @@ def test_all_listeners(app):
assert app.name + listener_name == output.pop()
@skipif_no_alarm
def test_all_listeners_as_convenience(app):
output = []
for listener_name in AVAILABLE_LISTENERS:
@@ -159,7 +148,6 @@ def test_create_server_trigger_events(app):
async def stop(app, loop):
nonlocal flag1
flag1 = True
signal.alarm(1)
async def before_stop(app, loop):
nonlocal flag2
@@ -178,10 +166,13 @@ def test_create_server_trigger_events(app):
# Use random port for tests
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(1)
with closing(socket()) as sock:
sock.bind(("127.0.0.1", 0))
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
serv_coro = app.create_server(
return_asyncio_server=True, sock=sock, debug=True
)
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task)
loop.run_until_complete(server.startup())
@@ -199,7 +190,6 @@ def test_create_server_trigger_events(app):
loop.run_until_complete(close_task)
# Complete all tasks on the loop
signal.stopped = True
for connection in server.connections:
connection.close_if_idle()
loop.run_until_complete(server.after_stop())

View File

@@ -33,6 +33,7 @@ def set_loop(app, loop):
def after(app, loop):
print("...")
calledq.put(mock.called)
@@ -48,10 +49,31 @@ def test_register_system_signals(app):
app.listener("before_server_start")(set_loop)
app.listener("after_server_stop")(after)
app.run(HOST, PORT)
app.run(HOST, PORT, single_process=True)
assert calledq.get() is True
@pytest.mark.skipif(os.name == "nt", reason="May hang CI on py38/windows")
def test_no_register_system_signals_fails(app):
"""Test if sanic don't register system signals"""
@app.route("/hello")
async def hello_route(request):
return HTTPResponse()
app.listener("after_server_start")(stop)
app.listener("before_server_start")(set_loop)
app.listener("after_server_stop")(after)
message = (
"Cannot run Sanic.serve with register_sys_signals=False. Use "
"either Sanic.serve_single or Sanic.serve_legacy."
)
with pytest.raises(RuntimeError, match=message):
app.prepare(HOST, PORT, register_sys_signals=False)
assert calledq.empty()
@pytest.mark.skipif(os.name == "nt", reason="May hang CI on py38/windows")
def test_dont_register_system_signals(app):
"""Test if sanic don't register system signals"""
@@ -64,7 +86,7 @@ def test_dont_register_system_signals(app):
app.listener("before_server_start")(set_loop)
app.listener("after_server_stop")(after)
app.run(HOST, PORT, register_sys_signals=False)
app.run(HOST, PORT, register_sys_signals=False, single_process=True)
assert calledq.get() is False

View File

@@ -1,6 +1,7 @@
import inspect
import logging
import os
import sys
from collections import Counter
from pathlib import Path
@@ -8,7 +9,7 @@ from time import gmtime, strftime
import pytest
from sanic import text
from sanic import Sanic, text
from sanic.exceptions import FileNotFound
@@ -21,6 +22,22 @@ def static_file_directory():
return static_directory
@pytest.fixture(scope="module")
def double_dotted_directory_file(static_file_directory: str):
"""Generate double dotted directory and its files"""
if sys.platform == "win32":
raise Exception("Windows doesn't support double dotted directories")
file_path = Path(static_file_directory) / "dotted.." / "dot.txt"
double_dotted_dir = file_path.parent
Path.mkdir(double_dotted_dir, exist_ok=True)
with open(file_path, "w") as f:
f.write("DOT\n")
yield file_path
Path.unlink(file_path)
Path.rmdir(double_dotted_dir)
def get_file_path(static_file_directory, file_name):
return os.path.join(static_file_directory, file_name)
@@ -578,3 +595,43 @@ def test_resource_type_dir(app, static_file_directory):
def test_resource_type_unknown(app, static_file_directory, caplog):
with pytest.raises(ValueError):
app.static("/static", static_file_directory, resource_type="unknown")
@pytest.mark.skipif(
sys.platform == "win32",
reason="Windows does not support double dotted directories",
)
def test_dotted_dir_ok(
app: Sanic, static_file_directory: str, double_dotted_directory_file: Path
):
app.static("/foo", static_file_directory)
dot_relative_path = str(
double_dotted_directory_file.relative_to(static_file_directory)
)
_, response = app.test_client.get("/foo/" + dot_relative_path)
assert response.status == 200
assert response.body == b"DOT\n"
def test_breakout(app: Sanic, static_file_directory: str):
app.static("/foo", static_file_directory)
_, response = app.test_client.get("/foo/..%2Ffake/server.py")
assert response.status == 404
_, response = app.test_client.get("/foo/..%2Fstatic/test.file")
assert response.status == 404
@pytest.mark.skipif(
sys.platform != "win32", reason="Block backslash on Windows only"
)
def test_double_backslash_prohibited_on_win32(
app: Sanic, static_file_directory: str
):
app.static("/foo", static_file_directory)
_, response = app.test_client.get("/foo/static/..\\static/test.file")
assert response.status == 404
_, response = app.test_client.get("/foo/static\\../static/test.file")
assert response.status == 404

View File

@@ -610,24 +610,24 @@ def test_get_ssl_context_only_mkcert(
MockTrustmeCreator.generate_cert.assert_not_called()
def test_no_http3_with_trustme(
app,
monkeypatch,
MockTrustmeCreator,
):
monkeypatch.setattr(
sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
)
MockTrustmeCreator.SUPPORTED = True
app.config.LOCAL_CERT_CREATOR = "TRUSTME"
with pytest.raises(
SanicException,
match=(
"Sorry, you cannot currently use trustme as a local certificate "
"generator for an HTTP/3 server"
),
):
app.run(version=3, debug=True)
# def test_no_http3_with_trustme(
# app,
# monkeypatch,
# MockTrustmeCreator,
# ):
# monkeypatch.setattr(
# sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
# )
# MockTrustmeCreator.SUPPORTED = True
# app.config.LOCAL_CERT_CREATOR = "TRUSTME"
# with pytest.raises(
# SanicException,
# match=(
# "Sorry, you cannot currently use trustme as a local certificate "
# "generator for an HTTP/3 server"
# ),
# ):
# app.run(version=3, debug=True)
def test_sanic_ssl_context_create():

View File

@@ -1,20 +1,26 @@
import asyncio
# import asyncio
import logging
import os
import platform
import subprocess
import sys
from asyncio import AbstractEventLoop
from string import ascii_lowercase
import httpcore
import httpx
import pytest
from pytest import LogCaptureFixture
from sanic import Sanic
from sanic.request import Request
from sanic.response import text
# import platform
# import subprocess
# import sys
pytestmark = pytest.mark.skipif(os.name != "posix", reason="UNIX only")
SOCKPATH = "/tmp/sanictest.sock"
SOCKPATH2 = "/tmp/sanictest2.sock"
@@ -45,7 +51,10 @@ def socket_cleanup():
pass
def test_unix_socket_creation(caplog):
@pytest.mark.xfail(
reason="Flaky Test on Non Linux Infra",
)
def test_unix_socket_creation(caplog: LogCaptureFixture):
from socket import AF_UNIX, socket
with socket(AF_UNIX) as sock:
@@ -55,14 +64,14 @@ def test_unix_socket_creation(caplog):
app = Sanic(name="test")
@app.listener("after_server_start")
def running(app, loop):
@app.after_server_start
def running(app: Sanic):
assert os.path.exists(SOCKPATH)
assert ino != os.stat(SOCKPATH).st_ino
app.stop()
with caplog.at_level(logging.INFO):
app.run(unix=SOCKPATH)
app.run(unix=SOCKPATH, single_process=True)
assert (
"sanic.root",
@@ -73,11 +82,11 @@ def test_unix_socket_creation(caplog):
@pytest.mark.parametrize("path", (".", "no-such-directory/sanictest.sock"))
def test_invalid_paths(path):
def test_invalid_paths(path: str):
app = Sanic(name="test")
#
with pytest.raises((FileExistsError, FileNotFoundError)):
app.run(unix=path)
app.run(unix=path, single_process=True)
def test_dont_replace_file():
@@ -86,12 +95,12 @@ def test_dont_replace_file():
app = Sanic(name="test")
@app.listener("after_server_start")
def stop(app, loop):
@app.after_server_start
def stop(app: Sanic):
app.stop()
with pytest.raises(FileExistsError):
app.run(unix=SOCKPATH)
app.run(unix=SOCKPATH, single_process=True)
def test_dont_follow_symlink():
@@ -103,47 +112,47 @@ def test_dont_follow_symlink():
app = Sanic(name="test")
@app.listener("after_server_start")
def stop(app, loop):
@app.after_server_start
def stop(app: Sanic):
app.stop()
with pytest.raises(FileExistsError):
app.run(unix=SOCKPATH)
app.run(unix=SOCKPATH, single_process=True)
def test_socket_deleted_while_running():
app = Sanic(name="test")
@app.listener("after_server_start")
async def hack(app, loop):
@app.after_server_start
async def hack(app: Sanic):
os.unlink(SOCKPATH)
app.stop()
app.run(host="myhost.invalid", unix=SOCKPATH)
app.run(host="myhost.invalid", unix=SOCKPATH, single_process=True)
def test_socket_replaced_with_file():
app = Sanic(name="test")
@app.listener("after_server_start")
async def hack(app, loop):
@app.after_server_start
async def hack(app: Sanic):
os.unlink(SOCKPATH)
with open(SOCKPATH, "w") as f:
f.write("Not a socket")
app.stop()
app.run(host="myhost.invalid", unix=SOCKPATH)
app.run(host="myhost.invalid", unix=SOCKPATH, single_process=True)
def test_unix_connection():
app = Sanic(name="test")
@app.get("/")
def handler(request):
def handler(request: Request):
return text(f"{request.conn_info.server}")
@app.listener("after_server_start")
async def client(app, loop):
@app.after_server_start
async def client(app: Sanic):
if httpx_version >= (0, 20):
transport = httpx.AsyncHTTPTransport(uds=SOCKPATH)
else:
@@ -156,17 +165,14 @@ def test_unix_connection():
finally:
app.stop()
app.run(host="myhost.invalid", unix=SOCKPATH)
app.run(host="myhost.invalid", unix=SOCKPATH, single_process=True)
app_multi = Sanic(name="test")
def handler(request):
def handler(request: Request):
return text(f"{request.conn_info.server}")
async def client(app, loop):
async def client(app: Sanic, loop: AbstractEventLoop):
try:
async with httpx.AsyncClient(uds=SOCKPATH) as client:
r = await client.get("http://myhost.invalid/")
@@ -177,86 +183,87 @@ async def client(app, loop):
def test_unix_connection_multiple_workers():
app_multi = Sanic(name="test")
app_multi.get("/")(handler)
app_multi.listener("after_server_start")(client)
app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2)
@pytest.mark.xfail(
condition=platform.system() != "Linux",
reason="Flaky Test on Non Linux Infra",
)
async def test_zero_downtime():
"""Graceful server termination and socket replacement on restarts"""
from signal import SIGINT
from time import monotonic as current_time
# @pytest.mark.xfail(
# condition=platform.system() != "Linux",
# reason="Flaky Test on Non Linux Infra",
# )
# async def test_zero_downtime():
# """Graceful server termination and socket replacement on restarts"""
# from signal import SIGINT
# from time import monotonic as current_time
async def client():
if httpx_version >= (0, 20):
transport = httpx.AsyncHTTPTransport(uds=SOCKPATH)
else:
transport = httpcore.AsyncConnectionPool(uds=SOCKPATH)
for _ in range(40):
async with httpx.AsyncClient(transport=transport) as client:
r = await client.get("http://localhost/sleep/0.1")
assert r.status_code == 200, r.text
assert r.text == "Slept 0.1 seconds.\n"
# async def client():
# if httpx_version >= (0, 20):
# transport = httpx.AsyncHTTPTransport(uds=SOCKPATH)
# else:
# transport = httpcore.AsyncConnectionPool(uds=SOCKPATH)
# for _ in range(40):
# async with httpx.AsyncClient(transport=transport) as client:
# r = await client.get("http://localhost/sleep/0.1")
# assert r.status_code == 200, r.text
# assert r.text == "Slept 0.1 seconds.\n"
def spawn():
command = [
sys.executable,
"-m",
"sanic",
"--debug",
"--unix",
SOCKPATH,
"examples.delayed_response.app",
]
DN = subprocess.DEVNULL
return subprocess.Popen(
command, stdin=DN, stdout=DN, stderr=subprocess.PIPE
)
# def spawn():
# command = [
# sys.executable,
# "-m",
# "sanic",
# "--debug",
# "--unix",
# SOCKPATH,
# "examples.delayed_response.app",
# ]
# DN = subprocess.DEVNULL
# return subprocess.Popen(
# command, stdin=DN, stdout=DN, stderr=subprocess.PIPE
# )
try:
processes = [spawn()]
while not os.path.exists(SOCKPATH):
if processes[0].poll() is not None:
raise Exception(
"Worker did not start properly. "
f"stderr: {processes[0].stderr.read()}"
)
await asyncio.sleep(0.0001)
ino = os.stat(SOCKPATH).st_ino
task = asyncio.get_event_loop().create_task(client())
start_time = current_time()
while current_time() < start_time + 4:
# Start a new one and wait until the socket is replaced
processes.append(spawn())
while ino == os.stat(SOCKPATH).st_ino:
await asyncio.sleep(0.001)
ino = os.stat(SOCKPATH).st_ino
# Graceful termination of the previous one
processes[-2].send_signal(SIGINT)
# Wait until client has completed all requests
await task
processes[-1].send_signal(SIGINT)
for worker in processes:
try:
worker.wait(1.0)
except subprocess.TimeoutExpired:
raise Exception(
f"Worker would not terminate:\n{worker.stderr}"
)
finally:
for worker in processes:
worker.kill()
# Test for clean run and termination
return_codes = [worker.poll() for worker in processes]
# try:
# processes = [spawn()]
# while not os.path.exists(SOCKPATH):
# if processes[0].poll() is not None:
# raise Exception(
# "Worker did not start properly. "
# f"stderr: {processes[0].stderr.read()}"
# )
# await asyncio.sleep(0.0001)
# ino = os.stat(SOCKPATH).st_ino
# task = asyncio.get_event_loop().create_task(client())
# start_time = current_time()
# while current_time() < start_time + 6:
# # Start a new one and wait until the socket is replaced
# processes.append(spawn())
# while ino == os.stat(SOCKPATH).st_ino:
# await asyncio.sleep(0.001)
# ino = os.stat(SOCKPATH).st_ino
# # Graceful termination of the previous one
# processes[-2].send_signal(SIGINT)
# # Wait until client has completed all requests
# await task
# processes[-1].send_signal(SIGINT)
# for worker in processes:
# try:
# worker.wait(1.0)
# except subprocess.TimeoutExpired:
# raise Exception(
# f"Worker would not terminate:\n{worker.stderr}"
# )
# finally:
# for worker in processes:
# worker.kill()
# # Test for clean run and termination
# return_codes = [worker.poll() for worker in processes]
# Removing last process which seems to be flappy
return_codes.pop()
assert len(processes) > 5
assert all(code == 0 for code in return_codes)
# # Removing last process which seems to be flappy
# return_codes.pop()
# assert len(processes) > 5
# assert all(code == 0 for code in return_codes)
# Removing this check that seems to be flappy
# assert not os.path.exists(SOCKPATH)
# # Removing this check that seems to be flappy
# # assert not os.path.exists(SOCKPATH)
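Note on the pattern running through the unix-socket tests above: listeners move to the @app.after_server_start shorthand (and may take only the app), and every app.run() call adds single_process=True, presumably so these tests serve in-process rather than through the new worker manager. A minimal standalone sketch of that pattern, assembled from the diff above rather than copied verbatim (the app name is illustrative; SOCKPATH is the value these tests use):

import os

from sanic import Sanic

SOCKPATH = "/tmp/sanictest.sock"

app = Sanic("pattern-demo")


@app.after_server_start
def stop(app: Sanic):
    # The listener only needs the app; asserting on the socket and then
    # stopping mirrors what the tests above do.
    assert os.path.exists(SOCKPATH)
    app.stop()


if __name__ == "__main__":
    # single_process=True keeps the server in this process, as in the updated tests.
    app.run(unix=SOCKPATH, single_process=True)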


@@ -0,0 +1,167 @@
import json
from datetime import datetime
from logging import ERROR, INFO
from socket import AF_INET, SOCK_STREAM, timeout
from unittest.mock import Mock, patch
import pytest
from sanic.log import Colors
from sanic.worker.inspector import Inspector, inspect
DATA = {
"info": {
"packages": ["foo"],
},
"extra": {
"more": "data",
},
"workers": {"Worker-Name": {"some": "state"}},
}
SERIALIZED = json.dumps(DATA)
def test_inspector_stop():
inspector = Inspector(Mock(), {}, {}, "", 1)
assert inspector.run is True
inspector.stop()
assert inspector.run is False
@patch("sanic.worker.inspector.sys.stdout.write")
@patch("sanic.worker.inspector.socket")
@pytest.mark.parametrize("command", ("foo", "raw", "pretty"))
def test_send_inspect(socket: Mock, write: Mock, command: str):
socket.return_value = socket
socket.__enter__.return_value = socket
socket.recv.return_value = SERIALIZED.encode()
inspect("localhost", 9999, command)
socket.sendall.assert_called_once_with(command.encode())
socket.recv.assert_called_once_with(4096)
socket.connect.assert_called_once_with(("localhost", 9999))
socket.assert_called_once_with(AF_INET, SOCK_STREAM)
if command == "raw":
write.assert_called_once_with(SERIALIZED)
elif command == "pretty":
write.assert_called()
else:
write.assert_not_called()
@patch("sanic.worker.inspector.sys")
@patch("sanic.worker.inspector.socket")
def test_send_inspect_conn_refused(socket: Mock, sys: Mock, caplog):
with caplog.at_level(INFO):
socket.return_value = socket
socket.__enter__.return_value = socket
socket.connect.side_effect = ConnectionRefusedError()
inspect("localhost", 9999, "foo")
socket.close.assert_called_once()
sys.exit.assert_called_once_with(1)
message = (
f"{Colors.RED}Could not connect to inspector at: "
f"{Colors.YELLOW}('localhost', 9999){Colors.END}\n"
"Either the application is not running, or it did not start "
"an inspector instance."
)
assert ("sanic.error", ERROR, message) in caplog.record_tuples
@patch("sanic.worker.inspector.configure_socket")
@pytest.mark.parametrize("action", (b"reload", b"shutdown", b"foo"))
def test_run_inspector(configure_socket: Mock, action: bytes):
sock = Mock()
conn = Mock()
conn.recv.return_value = action
configure_socket.return_value = sock
inspector = Inspector(Mock(), {}, {}, "localhost", 9999)
inspector.reload = Mock() # type: ignore
inspector.shutdown = Mock() # type: ignore
inspector.state_to_json = Mock(return_value="foo") # type: ignore
def accept():
inspector.run = False
return conn, ...
sock.accept = accept
inspector()
configure_socket.assert_called_once_with(
{"host": "localhost", "port": 9999, "unix": None, "backlog": 1}
)
conn.recv.assert_called_with(64)
if action == b"reload":
conn.send.assert_called_with(b"\n")
inspector.reload.assert_called()
inspector.shutdown.assert_not_called()
inspector.state_to_json.assert_not_called()
elif action == b"shutdown":
conn.send.assert_called_with(b"\n")
inspector.reload.assert_not_called()
inspector.shutdown.assert_called()
inspector.state_to_json.assert_not_called()
else:
conn.send.assert_called_with(b'"foo"')
inspector.reload.assert_not_called()
inspector.shutdown.assert_not_called()
inspector.state_to_json.assert_called()
@patch("sanic.worker.inspector.configure_socket")
def test_accept_timeout(configure_socket: Mock):
sock = Mock()
configure_socket.return_value = sock
inspector = Inspector(Mock(), {}, {}, "localhost", 9999)
inspector.reload = Mock() # type: ignore
inspector.shutdown = Mock() # type: ignore
inspector.state_to_json = Mock(return_value="foo") # type: ignore
def accept():
inspector.run = False
raise timeout
sock.accept = accept
inspector()
inspector.reload.assert_not_called()
inspector.shutdown.assert_not_called()
inspector.state_to_json.assert_not_called()
def test_state_to_json():
now = datetime.now()
now_iso = now.isoformat()
app_info = {"app": "hello"}
worker_state = {"Test": {"now": now, "nested": {"foo": now}}}
inspector = Inspector(Mock(), app_info, worker_state, "", 0)
state = inspector.state_to_json()
assert state == {
"info": app_info,
"workers": {"Test": {"now": now_iso, "nested": {"foo": now_iso}}},
}
def test_reload():
publisher = Mock()
inspector = Inspector(publisher, {}, {}, "", 0)
inspector.reload()
publisher.send.assert_called_once_with("__ALL_PROCESSES__:")
def test_shutdown():
publisher = Mock()
inspector = Inspector(publisher, {}, {}, "", 0)
inspector.shutdown()
publisher.send.assert_called_once_with("__TERMINATE__")
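Taken together, test_send_inspect and test_run_inspector pin down a small wire protocol: the inspector listens on a plain TCP socket, the client sends a short command, and the reply is either b"\n" (for reload/shutdown) or the JSON-serialized worker state read in a single 4096-byte recv. A standalone client sketch of that protocol, inferred from the assertions above rather than from the library source (the "state" command string is an arbitrary non-reload/shutdown command; host and port mirror the test values):

import json
from socket import AF_INET, SOCK_STREAM, socket


def query_inspector(host: str = "localhost", port: int = 9999) -> dict:
    # Any command other than b"reload" / b"shutdown" makes the inspector
    # answer with its JSON state (per the else branch of test_run_inspector).
    with socket(AF_INET, SOCK_STREAM) as sock:
        sock.connect((host, port))
        sock.sendall(b"state")
        raw = sock.recv(4096)  # the same buffer size the tests assert on
    return json.loads(raw.decode())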

tests/worker/test_loader.py

@@ -0,0 +1,102 @@
import sys
from os import getcwd
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from sanic.app import Sanic
from sanic.worker.loader import AppLoader, CertLoader
STATIC = Path.cwd() / "tests" / "static"
@pytest.mark.parametrize(
"module_input", ("tests.fake.server:app", "tests.fake.server.app")
)
def test_load_app_instance(module_input):
loader = AppLoader(module_input)
app = loader.load()
assert isinstance(app, Sanic)
@pytest.mark.parametrize(
"module_input",
("tests.fake.server:create_app", "tests.fake.server:create_app()"),
)
def test_load_app_factory(module_input):
loader = AppLoader(module_input, as_factory=True)
app = loader.load()
assert isinstance(app, Sanic)
def test_load_app_simple():
loader = AppLoader(str(STATIC), as_simple=True)
app = loader.load()
assert isinstance(app, Sanic)
def test_create_with_factory():
loader = AppLoader(factory=lambda: Sanic("Test"))
app = loader.load()
assert isinstance(app, Sanic)
def test_cwd_in_path():
AppLoader("tests.fake.server:app").load()
assert getcwd() in sys.path
def test_input_is_dir():
loader = AppLoader(str(STATIC))
message = (
"App not found.\n Please use --simple if you are passing a "
f"directory to sanic.\n eg. sanic {str(STATIC)} --simple"
)
with pytest.raises(ValueError, match=message):
loader.load()
def test_input_is_factory():
ns = SimpleNamespace(module="foo")
loader = AppLoader("tests.fake.server:create_app", args=ns)
message = (
"Module is not a Sanic app, it is a function\n If this callable "
"returns a Sanic instance try: \nsanic foo --factory"
)
with pytest.raises(ValueError, match=message):
loader.load()
def test_input_is_module():
ns = SimpleNamespace(module="foo")
loader = AppLoader("tests.fake.server", args=ns)
message = (
"Module is not a Sanic app, it is a module\n "
"Perhaps you meant foo:app?"
)
with pytest.raises(ValueError, match=message):
loader.load()
@pytest.mark.parametrize("creator", ("mkcert", "trustme"))
@patch("sanic.worker.loader.TrustmeCreator")
@patch("sanic.worker.loader.MkcertCreator")
def test_cert_loader(MkcertCreator: Mock, TrustmeCreator: Mock, creator: str):
MkcertCreator.return_value = MkcertCreator
TrustmeCreator.return_value = TrustmeCreator
data = {
"creator": creator,
"key": Path.cwd() / "tests" / "certs" / "localhost" / "privkey.pem",
"cert": Path.cwd() / "tests" / "certs" / "localhost" / "fullchain.pem",
"localhost": "localhost",
}
app = Sanic("Test")
loader = CertLoader(data) # type: ignore
loader.load(app)
creator_class = MkcertCreator if creator == "mkcert" else TrustmeCreator
creator_class.assert_called_once_with(app, data["key"], data["cert"])
creator_class.generate_cert.assert_called_once_with("localhost")


@@ -0,0 +1,217 @@
from signal import SIGINT, SIGKILL
from unittest.mock import Mock, call, patch
import pytest
from sanic.worker.manager import WorkerManager
def fake_serve():
...
def test_manager_no_workers():
message = "Cannot serve with no workers"
with pytest.raises(RuntimeError, match=message):
WorkerManager(
0,
fake_serve,
{},
Mock(),
(Mock(), Mock()),
{},
)
@patch("sanic.worker.process.os")
def test_terminate(os_mock: Mock):
process = Mock()
process.pid = 1234
context = Mock()
context.Process.return_value = process
manager = WorkerManager(
1,
fake_serve,
{},
context,
(Mock(), Mock()),
{},
)
assert manager.terminated is False
manager.terminate()
assert manager.terminated is True
os_mock.kill.assert_called_once_with(1234, SIGINT)
@patch("sanic.worker.process.os")
def test_shutdown(os_mock: Mock):
process = Mock()
process.pid = 1234
process.is_alive.return_value = True
context = Mock()
context.Process.return_value = process
manager = WorkerManager(
1,
fake_serve,
{},
context,
(Mock(), Mock()),
{},
)
manager.shutdown()
os_mock.kill.assert_called_once_with(1234, SIGINT)
@patch("sanic.worker.manager.os")
def test_kill(os_mock: Mock):
process = Mock()
process.pid = 1234
context = Mock()
context.Process.return_value = process
manager = WorkerManager(
1,
fake_serve,
{},
context,
(Mock(), Mock()),
{},
)
manager.kill()
os_mock.kill.assert_called_once_with(1234, SIGKILL)
def test_restart_all():
p1 = Mock()
p2 = Mock()
context = Mock()
context.Process.side_effect = [p1, p2, p1, p2]
manager = WorkerManager(
2,
fake_serve,
{},
context,
(Mock(), Mock()),
{},
)
assert len(list(manager.transient_processes))
manager.restart()
p1.terminate.assert_called_once()
p2.terminate.assert_called_once()
context.Process.assert_has_calls(
[
call(
name="Sanic-Server-0-0",
target=fake_serve,
kwargs={"config": {}},
daemon=True,
),
call(
name="Sanic-Server-1-0",
target=fake_serve,
kwargs={"config": {}},
daemon=True,
),
call(
name="Sanic-Server-0-0",
target=fake_serve,
kwargs={"config": {}},
daemon=True,
),
call(
name="Sanic-Server-1-0",
target=fake_serve,
kwargs={"config": {}},
daemon=True,
),
]
)
def test_monitor_all():
p1 = Mock()
p2 = Mock()
sub = Mock()
sub.recv.side_effect = ["__ALL_PROCESSES__:", ""]
context = Mock()
context.Process.side_effect = [p1, p2]
manager = WorkerManager(
2,
fake_serve,
{},
context,
(Mock(), sub),
{},
)
manager.restart = Mock() # type: ignore
manager.wait_for_ack = Mock() # type: ignore
manager.monitor()
manager.restart.assert_called_once_with(
process_names=None, reloaded_files=""
)
def test_monitor_all_with_files():
p1 = Mock()
p2 = Mock()
sub = Mock()
sub.recv.side_effect = ["__ALL_PROCESSES__:foo,bar", ""]
context = Mock()
context.Process.side_effect = [p1, p2]
manager = WorkerManager(
2,
fake_serve,
{},
context,
(Mock(), sub),
{},
)
manager.restart = Mock() # type: ignore
manager.wait_for_ack = Mock() # type: ignore
manager.monitor()
manager.restart.assert_called_once_with(
process_names=None, reloaded_files="foo,bar"
)
def test_monitor_one_process():
p1 = Mock()
p1.name = "Testing"
p2 = Mock()
sub = Mock()
sub.recv.side_effect = [f"{p1.name}:foo,bar", ""]
context = Mock()
context.Process.side_effect = [p1, p2]
manager = WorkerManager(
2,
fake_serve,
{},
context,
(Mock(), sub),
{},
)
manager.restart = Mock() # type: ignore
manager.wait_for_ack = Mock() # type: ignore
manager.monitor()
manager.restart.assert_called_once_with(
process_names=[p1.name], reloaded_files="foo,bar"
)
def test_shutdown_signal():
pub = Mock()
manager = WorkerManager(
1,
fake_serve,
{},
Mock(),
(pub, Mock()),
{},
)
manager.shutdown = Mock() # type: ignore
manager.shutdown_signal(SIGINT, None)
pub.send.assert_called_with(None)
manager.shutdown.assert_called_once_with()
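The manager tests above all construct WorkerManager with the same positional arguments: worker count, the serve callable, its keyword arguments, a multiprocessing context, a (publisher, subscriber) pair for the monitor channel, and a worker-state mapping. They also fix the monitor message format. A sketch of the messages that channel carries, assuming the transport is a multiprocessing Pipe (the tests themselves only use mocks):

from multiprocessing import Pipe

publisher, subscriber = Pipe()

# Restart every worker; an optional comma-separated list of changed files may
# follow the colon (see test_monitor_all_with_files).
publisher.send("__ALL_PROCESSES__:foo,bar")

# Restart a single worker by its process name (see test_monitor_one_process;
# the name matches the ones asserted in test_restart_all).
publisher.send("Sanic-Server-0-0:foo,bar")

# The inspector uses the same channel to request a full shutdown (see
# test_shutdown in the inspector tests above).
publisher.send("__TERMINATE__")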

Some files were not shown because too many files have changed in this diff.