Compare commits

...

33 Commits

Author SHA1 Message Date
Adam Hopkins
b2c0eed24d Bump version 2022-08-11 10:04:06 +03:00
Adam Hopkins
3abe4f885e Always show server location in ASGI (#2522)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Co-authored-by: Zhiwei Liang <zhi.wei.liang@outlook.com>
Co-authored-by: Néstor Pérez <25409753+prryplatypus@users.noreply.github.com>
2022-08-11 10:03:50 +03:00
Adam Hopkins
f4c8252185 Always show server location in ASGI 2022-08-07 23:15:05 +03:00
Adam Hopkins
daa1f8f2d5 Bump version 2022-07-31 14:16:48 +03:00
Adam Hopkins
0901d3188a Use path.parts instead of match (#2508) 2022-07-31 12:55:50 +03:00
Adam Hopkins
0985d130e2 Use pathlib for path resolution (#2506) 2022-07-31 12:55:50 +03:00
Adam Hopkins
7e0b0deb11 Fix dotted test 2022-07-31 12:55:49 +03:00
Néstor Pérez
e54ac3c6fd Prevent directory traversion with static files (#2495)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Co-authored-by: Zhiwei Liang <zhi.wei.liang@outlook.com>
2022-07-31 12:55:49 +03:00
Adam Hopkins
4429e76532 Add to changelog 2022-06-30 12:52:27 +03:00
Michael Azimov
e4be70bae8 Add custom loads function (#2445)
Co-authored-by: Zhiwei <chihwei.public@outlook.com>
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-06-29 23:39:21 +03:00
Adam Hopkins
13d5a44278 Release 22.6 (#2487) 2022-06-28 15:25:46 +03:00
Adam Hopkins
aba333bfb6 Improve API docs (#2488) 2022-06-28 10:53:03 +03:00
Adam Hopkins
b59da498cc HTTP/3 Support (#2378) 2022-06-27 11:19:26 +03:00
Zhiwei
70382f21ba Fix and improve file cache control header calculation (#2486) 2022-06-26 23:11:48 +03:00
Néstor Pérez
0e1bf89fad Add missing spaces in CLI error message (#2485) 2022-06-26 10:38:35 +03:00
Aidan Timson
6c48c8b3ba Fix for running in pythonw (#2448)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-06-19 14:48:06 +03:00
Zhiwei
d1c5e8003b Fix test_cli and test_cookies (#2479) 2022-06-19 04:43:12 +03:00
Adam Hopkins
ce926a34f2 Add Request contextvars (#2475)
* Add Request contextvars

* Add missing contextvar setter

* Move location of context setter
2022-06-16 22:57:02 +03:00
Zhiwei
a744041e38 File Cache Control Headers Support (#2447)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-06-16 16:24:39 +03:00
Mary
2f90a85df1 feat(type): extend (#2466)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-06-16 15:38:13 +03:00
Adam Hopkins
a411bc06e3 Resolve typing of stacked route definitions (#2455) 2022-06-16 15:15:20 +03:00
Adam Hopkins
1668e1532f Move verbosity filtering to logger (#2453) 2022-06-16 12:35:49 +03:00
Vetési Zoltán
b87982769f Trigger http.lifecycle.request signal in ASGI mode (#2451)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-06-16 11:55:50 +03:00
Ryu Juheon
65b53a5f3f style: add msg in `task.cancel` (#2416)
* style: add msg in ``task.cancel``

* style: apply isort

* fix: use else statement

* fix: use tuple

* fix: rollback for test

* fix: rollback like previous change

* fix: add ``=``

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-06-16 10:55:20 +03:00
Zhiwei
49789b7841 Clean Up Black and Isort Config (#2449)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-05-26 12:48:32 +03:00
Amitay
c249004c30 fixed manual to match current Sanic app name policy (#2461)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-05-26 10:16:24 +03:00
Ashley Sommer
4ee2e57ec8 Properly catch websocket CancelledError in websocket handler in Python 3.7 (#2463) 2022-05-23 22:47:05 +03:00
Néstor Pérez
86ae5f981c refactor: consistent exception naming (#2420)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-05-12 20:39:35 +03:00
Adam Hopkins
2bfa65e0de Current release mergeback (#2454) 2022-05-11 09:37:33 +03:00
Adam Hopkins
293278bb08 Resolve warning issue with error handler mismatch warning (#2452) 2022-05-11 09:36:05 +03:00
Michael Azimov
5d683c6ea4 Expose scope parameter in request object (#2432)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2022-04-26 17:25:29 +03:00
Stephen Sadowski
78b6723149 Preserve blank form values for urlencoded forms (option) (#2439)
* task(request.form): Add tests for blank values

* fix(request): abstract form property to implement get_form(), allow for preserving of blanks

* fix(request): hinting for parsed_form

* fix(request): typing for parsed_files

* fix(request): ignore type assumption

* fix(request): mypy typechecking caused E501 when type set to ignore

* fix(request): mypy is too stupid to parse continuations

* fix(request): formatting

* fix(request): fix annotation and return for get_form()

* fix(request): linting, hinting
2022-04-24 23:01:35 +03:00
Ryu juheon
3a6cc7389c feat: easier websocket interface annotation (#2438) 2022-04-24 13:32:13 +03:00
108 changed files with 3716 additions and 1097 deletions

View File

@@ -1,2 +0,0 @@
[tool.black]
line-length = 79

View File

@@ -20,6 +20,7 @@ exclude_lines =
noqa noqa
NOQA NOQA
pragma: no cover pragma: no cover
TYPE_CHECKING
omit = omit =
site-packages site-packages
sanic/__main__.py sanic/__main__.py

View File

@@ -313,8 +313,10 @@ Version 21.3.0
`#2074 <https://github.com/sanic-org/sanic/pull/2074>`_ `#2074 <https://github.com/sanic-org/sanic/pull/2074>`_
Performance adjustments in ``handle_request_`` Performance adjustments in ``handle_request_``
Version 20.12.3 Version 20.12.3 🔷
--------------- ------------------
`Current LTS version`
**Bugfixes** **Bugfixes**
@@ -348,8 +350,8 @@ Version 19.12.5
`#2027 <https://github.com/sanic-org/sanic/pull/2027>`_ `#2027 <https://github.com/sanic-org/sanic/pull/2027>`_
Remove old chardet requirement, add in hard multidict requirement Remove old chardet requirement, add in hard multidict requirement
Version 20.12.0 Version 20.12.0 🔹
--------------- -----------------
**Features** **Features**
@@ -357,11 +359,6 @@ Version 20.12.0
`#1993 <https://github.com/sanic-org/sanic/pull/1993>`_ `#1993 <https://github.com/sanic-org/sanic/pull/1993>`_
Add disable app registry Add disable app registry
Version 20.12.0
---------------
**Features**
* *
`#1945 <https://github.com/sanic-org/sanic/pull/1945>`_ `#1945 <https://github.com/sanic-org/sanic/pull/1945>`_
Static route more verbose if file not found Static route more verbose if file not found

View File

@@ -66,15 +66,15 @@ ifdef include_tests
isort -rc sanic tests isort -rc sanic tests
else else
$(info Sorting Imports) $(info Sorting Imports)
isort -rc sanic tests --profile=black isort -rc sanic tests
endif endif
endif endif
black: black:
black --config ./.black.toml sanic tests black sanic tests
isort: isort:
isort sanic tests --profile=black isort sanic tests
pretty: black isort pretty: black isort

View File

@@ -114,7 +114,7 @@ Hello World Example
from sanic import Sanic from sanic import Sanic
from sanic.response import json from sanic.response import json
app = Sanic("My Hello, world app") app = Sanic("my-hello-world-app")
@app.route('/') @app.route('/')
async def test(request): async def test(request):

View File

@@ -4,6 +4,7 @@ coverage:
default: default:
target: auto target: auto
threshold: 0.75 threshold: 0.75
informational: true
project: project:
default: default:
target: auto target: auto

View File

@@ -24,7 +24,11 @@ import sanic
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
extensions = ["sphinx.ext.autodoc", "m2r2"] extensions = [
"sphinx.ext.autodoc",
"m2r2",
"enum_tools.autoenum",
]
templates_path = ["_templates"] templates_path = ["_templates"]

View File

@@ -9,7 +9,7 @@ API
=== ===
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 3
👥 User Guide <https://sanicframework.org/guide/> 👥 User Guide <https://sanicframework.org/guide/>
sanic/api_reference sanic/api_reference

View File

@@ -15,3 +15,19 @@ sanic.config
.. automodule:: sanic.config .. automodule:: sanic.config
:members: :members:
:show-inheritance: :show-inheritance:
sanic.application.constants
---------------------------
.. automodule:: sanic.application.constants
:exclude-members: StrEnum
:members:
:show-inheritance:
:inherited-members:
sanic.application.state
-----------------------
.. automodule:: sanic.application.state
:members:
:show-inheritance:

View File

@@ -17,6 +17,14 @@ sanic.handlers
:show-inheritance: :show-inheritance:
sanic.headers
--------------
.. automodule:: sanic.headers
:members:
:show-inheritance:
sanic.request sanic.request
------------- -------------

View File

@@ -16,10 +16,3 @@ sanic.server
:members: :members:
:show-inheritance: :show-inheritance:
sanic.worker
------------
.. automodule:: sanic.worker
:members:
:show-inheritance:

View File

@@ -1,6 +1,7 @@
📜 Changelog 📜 Changelog
============ ============
.. mdinclude:: ./releases/22/22.6.md
.. mdinclude:: ./releases/22/22.3.md .. mdinclude:: ./releases/22/22.3.md
.. mdinclude:: ./releases/21/21.12.md .. mdinclude:: ./releases/21/21.12.md
.. mdinclude:: ./releases/21/21.9.md .. mdinclude:: ./releases/21/21.9.md

View File

@@ -1,10 +1,12 @@
## Version 21.12.1 ## Version 21.12.1 🔷
_Current LTS version_
- [#2349](https://github.com/sanic-org/sanic/pull/2349) Only display MOTD on startup - [#2349](https://github.com/sanic-org/sanic/pull/2349) Only display MOTD on startup
- [#2354](https://github.com/sanic-org/sanic/pull/2354) Ignore name argument in Python 3.7 - [#2354](https://github.com/sanic-org/sanic/pull/2354) Ignore name argument in Python 3.7
- [#2355](https://github.com/sanic-org/sanic/pull/2355) Add config.update support for all config values - [#2355](https://github.com/sanic-org/sanic/pull/2355) Add config.update support for all config values
## Version 21.12.0 ## Version 21.12.0 🔹
### Features ### Features
- [#2260](https://github.com/sanic-org/sanic/pull/2260) Allow early Blueprint registrations to still apply later added objects - [#2260](https://github.com/sanic-org/sanic/pull/2260) Allow early Blueprint registrations to still apply later added objects

View File

@@ -0,0 +1,43 @@
## Version 22.6.0 🔶
_Current version_
### Features
- [#2378](https://github.com/sanic-org/sanic/pull/2378) Introduce HTTP/3 and autogeneration of TLS certificates in `DEBUG` mode
- 👶 *EARLY RELEASE FEATURE*: Serving Sanic over HTTP/3 is an early release feature. It does not yet fully cover the HTTP/3 spec, but instead aims for feature parity with Sanic's existing HTTP/1.1 server. Websockets, WebTransport, push responses are examples of some features not yet implemented.
- 📦 *EXTRA REQUIREMENT*: Not all HTTP clients are capable of interfacing with HTTP/3 servers. You may need to install a [HTTP/3 capable client](https://curl.se/docs/http3.html).
- 📦 *EXTRA REQUIREMENT*: In order to use TLS autogeneration, you must install either [mkcert](https://github.com/FiloSottile/mkcert) or [trustme](https://github.com/python-trio/trustme).
- [#2416](https://github.com/sanic-org/sanic/pull/2416) Add message to `task.cancel`
- [#2420](https://github.com/sanic-org/sanic/pull/2420) Add exception aliases for more consistent naming with standard HTTP response types (`BadRequest`, `MethodNotAllowed`, `RangeNotSatisfiable`)
- [#2432](https://github.com/sanic-org/sanic/pull/2432) Expose ASGI `scope` as a property on the `Request` object
- [#2438](https://github.com/sanic-org/sanic/pull/2438) Easier access to websocket class for annotation: `from sanic import Websocket`
- [#2439](https://github.com/sanic-org/sanic/pull/2439) New API for reading form values with options: `Request.get_form`
- [#2445](https://github.com/sanic-org/sanic/pull/2445) Add custom `loads` function
- [#2447](https://github.com/sanic-org/sanic/pull/2447), [#2486](https://github.com/sanic-org/sanic/pull/2486) Improved API to support setting cache control headers
- [#2453](https://github.com/sanic-org/sanic/pull/2453) Move verbosity filtering to logger
- [#2475](https://github.com/sanic-org/sanic/pull/2475) Expose getter for current request using `Request.get_current()`
### Bugfixes
- [#2448](https://github.com/sanic-org/sanic/pull/2448) Fix to allow running with `pythonw.exe` or places where there is no `sys.stdout`
- [#2451](https://github.com/sanic-org/sanic/pull/2451) Trigger `http.lifecycle.request` signal in ASGI mode
- [#2455](https://github.com/sanic-org/sanic/pull/2455) Resolve typing of stacked route definitions
- [#2463](https://github.com/sanic-org/sanic/pull/2463) Properly catch websocket CancelledError in websocket handler in Python 3.7
### Deprecations and Removals
- [#2487](https://github.com/sanic-org/sanic/pull/2487) v22.6 deprecations and changes
1. Optional application registry
1. Execution of custom handlers after some part of response was sent
1. Configuring fallback handlers on the `ErrorHandler`
1. Custom `LOGO` setting
1. `sanic.response.stream`
1. `AsyncioServer.init`
### Developer infrastructure
- [#2449](https://github.com/sanic-org/sanic/pull/2449) Clean up `black` and `isort` config
- [#2479](https://github.com/sanic-org/sanic/pull/2479) Fix some flappy tests
### Improved Documentation
- [#2461](https://github.com/sanic-org/sanic/pull/2461) Update example to match current application naming standards
- [#2466](https://github.com/sanic-org/sanic/pull/2466) Better type annotation for `Extend`
- [#2485](https://github.com/sanic-org/sanic/pull/2485) Improved help messages in CLI

View File

@@ -1,3 +1,26 @@
[build-system] [build-system]
requires = ["setuptools<60.0", "wheel"] requires = ["setuptools<60.0", "wheel"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
[tool.isort]
atomic = true
default_section = "THIRDPARTY"
include_trailing_comma = true
known_first_party = "sanic"
known_third_party = "pytest"
line_length = 79
lines_after_imports = 2
lines_between_types = 1
multi_line_output = 3
profile = "black"
[[tool.mypy.overrides]]
module = [
"httptools.*",
"trustme.*",
"sanic_routing.*",
]
ignore_missing_imports = true

View File

@@ -4,6 +4,7 @@ from sanic.blueprints import Blueprint
from sanic.constants import HTTPMethod from sanic.constants import HTTPMethod
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
from sanic.server.websockets.impl import WebsocketImplProtocol as Websocket
__all__ = ( __all__ = (
@@ -13,6 +14,7 @@ __all__ = (
"HTTPMethod", "HTTPMethod",
"HTTPResponse", "HTTPResponse",
"Request", "Request",
"Websocket",
"html", "html",
"json", "json",
"text", "text",

View File

@@ -1 +1 @@
__version__ = "22.3.1" __version__ = "22.6.2"

View File

@@ -43,11 +43,8 @@ from typing import (
from urllib.parse import urlencode, urlunparse from urllib.parse import urlencode, urlunparse
from warnings import filterwarnings from warnings import filterwarnings
from sanic_routing.exceptions import ( # type: ignore from sanic_routing.exceptions import FinalizationError, NotFound
FinalizationError, from sanic_routing.route import Route
NotFound,
)
from sanic_routing.route import Route # type: ignore
from sanic.application.ext import setup_ext from sanic.application.ext import setup_ext
from sanic.application.state import ApplicationState, Mode, ServerStage from sanic.application.state import ApplicationState, Mode, ServerStage
@@ -58,12 +55,13 @@ from sanic.blueprints import Blueprint
from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support
from sanic.config import SANIC_PREFIX, Config from sanic.config import SANIC_PREFIX, Config
from sanic.exceptions import ( from sanic.exceptions import (
InvalidUsage, BadRequest,
SanicException, SanicException,
ServerError, ServerError,
URLBuildError, URLBuildError,
) )
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.helpers import _default
from sanic.http import Stage from sanic.http import Stage
from sanic.log import ( from sanic.log import (
LOGGING_CONFIG_DEFAULTS, LOGGING_CONFIG_DEFAULTS,
@@ -92,12 +90,12 @@ from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta from sanic.touchup import TouchUp, TouchUpMeta
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
try: try:
from sanic_ext import Extend # type: ignore from sanic_ext import Extend # type: ignore
from sanic_ext.extensions.base import Extension # type: ignore from sanic_ext.extensions.base import Extension # type: ignore
except ImportError: except ImportError:
Extend = TypeVar("Extend") # type: ignore Extend = TypeVar("Extend", Type) # type: ignore
if OS_IS_WINDOWS: # no cov if OS_IS_WINDOWS: # no cov
@@ -171,8 +169,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
strict_slashes: bool = False, strict_slashes: bool = False,
log_config: Optional[Dict[str, Any]] = None, log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True, configure_logging: bool = True,
register: Optional[bool] = None,
dumps: Optional[Callable[..., AnyStr]] = None, dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
) -> None: ) -> None:
super().__init__(name=name) super().__init__(name=name)
@@ -220,23 +218,14 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# Register alternative method names # Register alternative method names
self.go_fast = self.run self.go_fast = self.run
if register is not None:
deprecation(
"The register argument is deprecated and will stop working "
"in v22.6. After v22.6 all apps will be added to the Sanic "
"app registry.",
22.6,
)
self.config.REGISTER = register
if self.config.REGISTER:
self.__class__.register_app(self)
self.router.ctx.app = self self.router.ctx.app = self
self.signal_router.ctx.app = self self.signal_router.ctx.app = self
self.__class__.register_app(self)
if dumps: if dumps:
BaseHTTPResponse._dumps = dumps # type: ignore BaseHTTPResponse._dumps = dumps # type: ignore
if loads:
Request._loads = loads # type: ignore
@property @property
def loop(self): def loop(self):
@@ -281,7 +270,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
valid = ", ".join( valid = ", ".join(
map(lambda x: x.lower(), ListenerEvent.__members__.keys()) map(lambda x: x.lower(), ListenerEvent.__members__.keys())
) )
raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") raise BadRequest(f"Invalid event: {event}. Use one of: {valid}")
if "." in _event: if "." in _event:
self.signal(_event.value)( self.signal(_event.value)(
@@ -738,37 +727,24 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"has at least partially been sent." "has at least partially been sent."
) )
# ----------------- deprecated -----------------
handler = self.error_handler._lookup( handler = self.error_handler._lookup(
exception, request.name if request else None exception, request.name if request else None
) )
if handler: if handler:
deprecation( logger.warning(
"An error occurred while handling the request after at " "An error occurred while handling the request after at "
"least some part of the response was sent to the client. " "least some part of the response was sent to the client. "
"Therefore, the response from your custom exception " "The response from your custom exception handler "
f"handler {handler.__name__} will not be sent to the " f"{handler.__name__} will not be sent to the client."
"client. Beginning in v22.6, Sanic will stop executing " "Exception handlers should only be used to generate the "
"custom exception handlers in this scenario. Exception " "exception responses. If you would like to perform any "
"handlers should only be used to generate the exception " "other action on a raised exception, consider using a "
"responses. If you would like to perform any other "
"action on a raised exception, please consider using a "
"signal handler like " "signal handler like "
'`@app.signal("http.lifecycle.exception")`\n' '`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: " "For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/" "https://sanicframework.org/en/guide/advanced/"
"signals.html", "signals.html",
22.6,
) )
try:
response = self.error_handler.response(request, exception)
if isawaitable(response):
response = await response
except BaseException as e:
logger.error("An error occurred in the exception handler.")
error_logger.exception(e)
# ----------------------------------------------
return return
# -------------------------------------------- # # -------------------------------------------- #
@@ -949,6 +925,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"response": response, "response": response,
}, },
) )
...
await response.send(end_stream=True) await response.send(end_stream=True)
elif isinstance(response, ResponseStream): elif isinstance(response, ResponseStream):
resp = await response(request) resp = await response(request)
@@ -992,10 +969,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
cancelled = False cancelled = False
try: try:
await fut await fut
except Exception as e:
self.error_handler.log(request, e)
except (CancelledError, ConnectionClosed): except (CancelledError, ConnectionClosed):
cancelled = True cancelled = True
except Exception as e:
self.error_handler.log(request, e)
finally: finally:
self.websocket_tasks.remove(fut) self.websocket_tasks.remove(fut)
if cancelled: if cancelled:
@@ -1338,7 +1315,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
self.config.update_config(config) self.config.update_config(config)
@property @property
def asgi(self): def asgi(self) -> bool:
return self.state.asgi return self.state.asgi
@asgi.setter @asgi.setter
@@ -1370,7 +1347,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
self.config.AUTO_RELOAD = value self.config.AUTO_RELOAD = value
@property @property
def state(self): def state(self) -> ApplicationState: # type: ignore
"""
:return: The application state
"""
return self._state return self._state
@property @property
@@ -1532,8 +1512,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
if hasattr(self, "_ext"): if hasattr(self, "_ext"):
self.ext._display() self.ext._display()
if self.state.is_debug: if self.state.is_debug and self.config.TOUCHUP is not True:
self.config.TOUCHUP = False self.config.TOUCHUP = False
elif self.config.TOUCHUP is _default:
self.config.TOUCHUP = True
# Setup routers # Setup routers
self.signalize(self.config.TOUCHUP) self.signalize(self.config.TOUCHUP)
@@ -1555,7 +1537,6 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
if self.state.primary: if self.state.primary:
# TODO: # TODO:
# - Raise warning if secondary apps have error handler config # - Raise warning if secondary apps have error handler config
ErrorHandler.finalize(self.error_handler, config=self.config)
if self.config.TOUCHUP: if self.config.TOUCHUP:
TouchUp.run(self) TouchUp.run(self)
@@ -1573,8 +1554,9 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"shutdown", "shutdown",
): ):
raise SanicException(f"Invalid server event: {event}") raise SanicException(f"Invalid server event: {event}")
if self.state.verbosity >= 1: logger.debug(
logger.debug(f"Triggering server events: {event}") f"Triggering server events: {event}", extra={"verbosity": 1}
)
reverse = concern == "shutdown" reverse = concern == "shutdown"
if loop is None: if loop is None:
loop = self.loop loop = self.loop

View File

@@ -0,0 +1,23 @@
from enum import Enum, IntEnum, auto
class StrEnum(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
class Server(StrEnum):
SANIC = auto()
ASGI = auto()
GUNICORN = auto()
class Mode(StrEnum):
PRODUCTION = auto()
DEBUG = auto()
class ServerStage(IntEnum):
STOPPED = auto()
PARTIAL = auto()
SERVING = auto()

View File

@@ -5,7 +5,7 @@ from importlib import import_module
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic
try: try:

View File

@@ -3,6 +3,8 @@ import sys
from os import environ from os import environ
from sanic.compat import is_atty
BASE_LOGO = """ BASE_LOGO = """
@@ -44,7 +46,7 @@ ansi_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
def get_logo(full=False, coffee=False): def get_logo(full=False, coffee=False):
logo = ( logo = (
(FULL_COLOR_LOGO if full else (COFFEE_LOGO if coffee else COLOR_LOGO)) (FULL_COLOR_LOGO if full else (COFFEE_LOGO if coffee else COLOR_LOGO))
if sys.stdout.isatty() if is_atty()
else BASE_LOGO else BASE_LOGO
) )

View File

@@ -1,11 +1,10 @@
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from shutil import get_terminal_size from shutil import get_terminal_size
from textwrap import indent, wrap from textwrap import indent, wrap
from typing import Dict, Optional from typing import Dict, Optional
from sanic import __version__ from sanic import __version__
from sanic.compat import is_atty
from sanic.log import logger from sanic.log import logger
@@ -36,7 +35,7 @@ class MOTD(ABC):
data: Dict[str, str], data: Dict[str, str],
extra: Dict[str, str], extra: Dict[str, str],
) -> None: ) -> None:
motd_class = MOTDTTY if sys.stdout.isatty() else MOTDBasic motd_class = MOTDTTY if is_atty() else MOTDBasic
motd_class(logo, serve_location, data, extra).display() motd_class(logo, serve_location, data, extra).display()

View File

@@ -0,0 +1,86 @@
import os
import sys
import time
from contextlib import contextmanager
from queue import Queue
from threading import Thread
if os.name == "nt": # noqa
import ctypes # noqa
class _CursorInfo(ctypes.Structure):
_fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)]
class Spinner: # noqa
def __init__(self, message: str) -> None:
self.message = message
self.queue: Queue[int] = Queue()
self.spinner = self.cursor()
self.thread = Thread(target=self.run)
def start(self):
self.queue.put(1)
self.thread.start()
self.hide()
def run(self):
while self.queue.get():
output = f"\r{self.message} [{next(self.spinner)}]"
sys.stdout.write(output)
sys.stdout.flush()
time.sleep(0.1)
self.queue.put(1)
def stop(self):
self.queue.put(0)
self.thread.join()
self.show()
@staticmethod
def cursor():
while True:
for cursor in "|/-\\":
yield cursor
@staticmethod
def hide():
if os.name == "nt":
ci = _CursorInfo()
handle = ctypes.windll.kernel32.GetStdHandle(-11)
ctypes.windll.kernel32.GetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
ci.visible = False
ctypes.windll.kernel32.SetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
elif os.name == "posix":
sys.stdout.write("\033[?25l")
sys.stdout.flush()
@staticmethod
def show():
if os.name == "nt":
ci = _CursorInfo()
handle = ctypes.windll.kernel32.GetStdHandle(-11)
ctypes.windll.kernel32.GetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
ci.visible = True
ctypes.windll.kernel32.SetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
elif os.name == "posix":
sys.stdout.write("\033[?25h")
sys.stdout.flush()
@contextmanager
def loading(message: str = "Loading"): # noqa
spinner = Spinner(message)
spinner.start()
yield
spinner.stop()

View File

@@ -3,42 +3,20 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from pathlib import Path from pathlib import Path
from socket import socket from socket import socket
from ssl import SSLContext from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from sanic.log import logger from sanic.application.constants import Mode, Server, ServerStage
from sanic.log import VerbosityFilter, logger
from sanic.server.async_server import AsyncioServer from sanic.server.async_server import AsyncioServer
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic
class StrEnum(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
class Server(StrEnum):
SANIC = auto()
ASGI = auto()
GUNICORN = auto()
class Mode(StrEnum):
PRODUCTION = auto()
DEBUG = auto()
class ServerStage(IntEnum):
STOPPED = auto()
PARTIAL = auto()
SERVING = auto()
@dataclass @dataclass
class ApplicationServerInfo: class ApplicationServerInfo:
settings: Dict[str, Any] settings: Dict[str, Any]
@@ -91,6 +69,9 @@ class ApplicationState:
if getattr(self.app, "configure_logging", False) and self.app.debug: if getattr(self.app, "configure_logging", False) and self.app.debug:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
def set_verbosity(self, value: int):
VerbosityFilter.verbosity = value
@property @property
def is_debug(self): def is_debug(self):
return self.mode is Mode.DEBUG return self.mode is Mode.DEBUG

View File

@@ -17,7 +17,7 @@ from sanic.server import ConnInfo
from sanic.server.websockets.connection import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic
@@ -25,27 +25,28 @@ class Lifespan:
def __init__(self, asgi_app: ASGIApp) -> None: def __init__(self, asgi_app: ASGIApp) -> None:
self.asgi_app = asgi_app self.asgi_app = asgi_app
if self.asgi_app.sanic_app.state.verbosity > 0: if (
if ( "server.init.before"
"server.init.before" in self.asgi_app.sanic_app.signal_router.name_index
in self.asgi_app.sanic_app.signal_router.name_index ):
): logger.debug(
logger.debug( 'You have set a listener for "before_server_start" '
'You have set a listener for "before_server_start" ' "in ASGI mode. "
"in ASGI mode. " "It will be executed as early as possible, but not before "
"It will be executed as early as possible, but not before " "the ASGI server is started.",
"the ASGI server is started." extra={"verbosity": 1},
) )
if ( if (
"server.shutdown.after" "server.shutdown.after"
in self.asgi_app.sanic_app.signal_router.name_index in self.asgi_app.sanic_app.signal_router.name_index
): ):
logger.debug( logger.debug(
'You have set a listener for "after_server_stop" ' 'You have set a listener for "after_server_stop" '
"in ASGI mode. " "in ASGI mode. "
"It will be executed as late as possible, but not after " "It will be executed as late as possible, but not after "
"the ASGI server is stopped." "the ASGI server is stopped.",
) extra={"verbosity": 1},
)
async def startup(self) -> None: async def startup(self) -> None:
""" """
@@ -163,6 +164,13 @@ class ASGIApp:
instance.request_body = True instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport) instance.request.conn_info = ConnInfo(instance.transport)
await sanic_app.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": instance.request},
fail_not_found=False,
)
return instance return instance
async def read(self) -> Optional[bytes]: async def read(self) -> Optional[bytes]:

View File

@@ -5,7 +5,7 @@ from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint

View File

@@ -21,8 +21,8 @@ from typing import (
Union, Union,
) )
from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.exceptions import NotFound
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route
from sanic.base.root import BaseSanic from sanic.base.root import BaseSanic
from sanic.blueprint_group import BlueprintGroup from sanic.blueprint_group import BlueprintGroup
@@ -36,7 +36,7 @@ from sanic.models.handler_types import (
) )
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic

View File

@@ -58,10 +58,13 @@ Or, a path to a directory to run as a simple HTTP server:
os.environ.get("SANIC_RELOADER_PROCESS", "") != "true" os.environ.get("SANIC_RELOADER_PROCESS", "") != "true"
) )
self.args: List[Any] = [] self.args: List[Any] = []
self.groups: List[Group] = []
def attach(self): def attach(self):
for group in Group._registry: for group in Group._registry:
group.create(self.parser).attach() instance = group.create(self.parser)
instance.attach()
self.groups.append(instance)
def run(self): def run(self):
# This is to provide backwards compat -v to display version # This is to provide backwards compat -v to display version
@@ -81,9 +84,13 @@ Or, a path to a directory to run as a simple HTTP server:
try: try:
app = self._get_app() app = self._get_app()
kwargs = self._build_run_kwargs() kwargs = self._build_run_kwargs()
app.run(**kwargs)
except ValueError: except ValueError:
error_logger.exception("Failed to run app") error_logger.exception("Failed to run app")
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)
Sanic.serve()
def _precheck(self): def _precheck(self):
# # Custom TLS mismatch handling for better diagnostics # # Custom TLS mismatch handling for better diagnostics
@@ -146,9 +153,9 @@ Or, a path to a directory to run as a simple HTTP server:
if callable(app): if callable(app):
solution = f"sanic {self.args.module} --factory" solution = f"sanic {self.args.module} --factory"
raise ValueError( raise ValueError(
"Module is not a Sanic app, it is a" "Module is not a Sanic app, it is a "
f"{app_type_name}\n" f"{app_type_name}\n"
" If this callable returns a" " If this callable returns a "
f"Sanic instance try: \n{solution}" f"Sanic instance try: \n{solution}"
) )
@@ -163,11 +170,14 @@ Or, a path to a directory to run as a simple HTTP server:
" Example File: project/sanic_server.py -> app\n" " Example File: project/sanic_server.py -> app\n"
" Example Module: project.sanic_server.app" " Example Module: project.sanic_server.app"
) )
sys.exit(1)
else: else:
raise e raise e
return app return app
def _build_run_kwargs(self): def _build_run_kwargs(self):
for group in self.groups:
group.prepare(self.args)
ssl: Union[None, dict, str, list] = [] ssl: Union[None, dict, str, list] = []
if self.args.tlshost: if self.args.tlshost:
ssl.append(None) ssl.append(None)
@@ -192,6 +202,7 @@ Or, a path to a directory to run as a simple HTTP server:
"unix": self.args.unix, "unix": self.args.unix,
"verbosity": self.args.verbosity or 0, "verbosity": self.args.verbosity or 0,
"workers": self.args.workers, "workers": self.args.workers,
"auto_tls": self.args.auto_tls,
} }
for maybe_arg in ("auto_reload", "dev"): for maybe_arg in ("auto_reload", "dev"):
@@ -201,4 +212,5 @@ Or, a path to a directory to run as a simple HTTP server:
if self.args.path: if self.args.path:
kwargs["auto_reload"] = True kwargs["auto_reload"] = True
kwargs["reload_dir"] = self.args.path kwargs["reload_dir"] = self.args.path
return kwargs return kwargs

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
from argparse import ArgumentParser, _ArgumentGroup from argparse import ArgumentParser, _ArgumentGroup
from typing import List, Optional, Type, Union from typing import List, Optional, Type, Union
from sanic_routing import __version__ as __routing_version__ # type: ignore from sanic_routing import __version__ as __routing_version__
from sanic import __version__ from sanic import __version__
from sanic.http.constants import HTTP
class Group: class Group:
@@ -38,6 +39,9 @@ class Group:
"--no-" + args[0][2:], *args[1:], action="store_false", **kwargs "--no-" + args[0][2:], *args[1:], action="store_false", **kwargs
) )
def prepare(self, args) -> None:
...
class GeneralGroup(Group): class GeneralGroup(Group):
name = None name = None
@@ -83,6 +87,44 @@ class ApplicationGroup(Group):
) )
class HTTPVersionGroup(Group):
name = "HTTP version"
def attach(self):
http_values = [http.value for http in HTTP.__members__.values()]
self.container.add_argument(
"--http",
dest="http",
action="append",
choices=http_values,
type=int,
help=(
"Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should\n"
"be either 1, or 3. [default 1]"
),
)
self.container.add_argument(
"-1",
dest="http",
action="append_const",
const=1,
help=("Run Sanic server using HTTP/1.1"),
)
self.container.add_argument(
"-3",
dest="http",
action="append_const",
const=3,
help=("Run Sanic server using HTTP/3"),
)
def prepare(self, args):
if not args.http:
args.http = [1]
args.http = tuple(sorted(set(map(HTTP, args.http)), reverse=True))
class SocketGroup(Group): class SocketGroup(Group):
name = "Socket binding" name = "Socket binding"
@@ -92,7 +134,6 @@ class SocketGroup(Group):
"--host", "--host",
dest="host", dest="host",
type=str, type=str,
default="127.0.0.1",
help="Host address [default 127.0.0.1]", help="Host address [default 127.0.0.1]",
) )
self.container.add_argument( self.container.add_argument(
@@ -100,7 +141,6 @@ class SocketGroup(Group):
"--port", "--port",
dest="port", dest="port",
type=int, type=int,
default=8000,
help="Port to serve on [default 8000]", help="Port to serve on [default 8000]",
) )
self.container.add_argument( self.container.add_argument(
@@ -180,11 +220,7 @@ class DevelopmentGroup(Group):
"--debug", "--debug",
dest="debug", dest="debug",
action="store_true", action="store_true",
help=( help="Run the server in debug mode",
"Run the server in DEBUG mode. It includes DEBUG logging,\n"
"additional context on exceptions, and other settings\n"
"not-safe for PRODUCTION, but helpful for debugging problems."
),
) )
self.container.add_argument( self.container.add_argument(
"-r", "-r",
@@ -209,7 +245,16 @@ class DevelopmentGroup(Group):
"--dev", "--dev",
dest="dev", dest="dev",
action="store_true", action="store_true",
help=("debug + auto reload."), help=("debug + auto reload"),
)
self.container.add_argument(
"--auto-tls",
dest="auto_tls",
action="store_true",
help=(
"Create a temporary TLS certificate for local development "
"(requires mkcert or trustme)"
),
) )

View File

@@ -1,8 +1,9 @@
import asyncio import asyncio
import os import os
import signal import signal
import sys
from sys import argv from typing import Awaitable
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
@@ -47,12 +48,12 @@ class Header(CIMultiDict):
return self.getall(key, default=[]) return self.getall(key, default=[])
use_trio = argv[0].endswith("hypercorn") and "trio" in argv use_trio = sys.argv[0].endswith("hypercorn") and "trio" in sys.argv
if use_trio: # pragma: no cover if use_trio: # pragma: no cover
import trio # type: ignore import trio # type: ignore
def stat_async(path): def stat_async(path) -> Awaitable[os.stat_result]:
return trio.Path(path).stat() return trio.Path(path).stat()
open_async = trio.open_file open_async = trio.open_file
@@ -89,3 +90,7 @@ def ctrlc_workaround_for_windows(app):
die = False die = False
signal.signal(signal.SIGINT, ctrlc_handler) signal.signal(signal.SIGINT, ctrlc_handler)
app.add_task(stay_active) app.add_task(stay_active)
def is_atty() -> bool:
return bool(sys.stdout and sys.stdout.isatty())

View File

@@ -5,6 +5,7 @@ from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, Union from typing import Any, Callable, Dict, Optional, Sequence, Union
from sanic.constants import LocalCertCreator
from sanic.errorpages import DEFAULT_FORMAT, check_error_format from sanic.errorpages import DEFAULT_FORMAT, check_error_format
from sanic.helpers import Default, _default from sanic.helpers import Default, _default
from sanic.http import Http from sanic.http import Http
@@ -26,19 +27,23 @@ DEFAULT_CONFIG = {
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
"KEEP_ALIVE_TIMEOUT": 5, # 5 seconds "KEEP_ALIVE_TIMEOUT": 5, # 5 seconds
"KEEP_ALIVE": True, "KEEP_ALIVE": True,
"LOCAL_CERT_CREATOR": LocalCertCreator.AUTO,
"LOCAL_TLS_KEY": _default,
"LOCAL_TLS_CERT": _default,
"LOCALHOST": "localhost",
"MOTD": True, "MOTD": True,
"MOTD_DISPLAY": {}, "MOTD_DISPLAY": {},
"NOISY_EXCEPTIONS": False, "NOISY_EXCEPTIONS": False,
"PROXIES_COUNT": None, "PROXIES_COUNT": None,
"REAL_IP_HEADER": None, "REAL_IP_HEADER": None,
"REGISTER": True,
"REQUEST_BUFFER_SIZE": 65536, # 64 KiB "REQUEST_BUFFER_SIZE": 65536, # 64 KiB
"REQUEST_MAX_HEADER_SIZE": 8192, # 8 KiB, but cannot exceed 16384 "REQUEST_MAX_HEADER_SIZE": 8192, # 8 KiB, but cannot exceed 16384
"REQUEST_ID_HEADER": "X-Request-ID", "REQUEST_ID_HEADER": "X-Request-ID",
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_MAX_SIZE": 100000000, # 100 megabytes
"REQUEST_TIMEOUT": 60, # 60 seconds "REQUEST_TIMEOUT": 60, # 60 seconds
"RESPONSE_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds
"TOUCHUP": True, "TLS_CERT_PASSWORD": "",
"TOUCHUP": _default,
"USE_UVLOOP": _default, "USE_UVLOOP": _default,
"WEBSOCKET_MAX_SIZE": 2**20, # 1 megabyte "WEBSOCKET_MAX_SIZE": 2**20, # 1 megabyte
"WEBSOCKET_PING_INTERVAL": 20, "WEBSOCKET_PING_INTERVAL": 20,
@@ -69,12 +74,15 @@ class Config(dict, metaclass=DescriptorMeta):
GRACEFUL_SHUTDOWN_TIMEOUT: float GRACEFUL_SHUTDOWN_TIMEOUT: float
KEEP_ALIVE_TIMEOUT: int KEEP_ALIVE_TIMEOUT: int
KEEP_ALIVE: bool KEEP_ALIVE: bool
NOISY_EXCEPTIONS: bool LOCAL_CERT_CREATOR: Union[str, LocalCertCreator]
LOCAL_TLS_KEY: Union[Path, str, Default]
LOCAL_TLS_CERT: Union[Path, str, Default]
LOCALHOST: str
MOTD: bool MOTD: bool
MOTD_DISPLAY: Dict[str, str] MOTD_DISPLAY: Dict[str, str]
NOISY_EXCEPTIONS: bool
PROXIES_COUNT: Optional[int] PROXIES_COUNT: Optional[int]
REAL_IP_HEADER: Optional[str] REAL_IP_HEADER: Optional[str]
REGISTER: bool
REQUEST_BUFFER_SIZE: int REQUEST_BUFFER_SIZE: int
REQUEST_MAX_HEADER_SIZE: int REQUEST_MAX_HEADER_SIZE: int
REQUEST_ID_HEADER: str REQUEST_ID_HEADER: str
@@ -82,7 +90,8 @@ class Config(dict, metaclass=DescriptorMeta):
REQUEST_TIMEOUT: int REQUEST_TIMEOUT: int
RESPONSE_TIMEOUT: int RESPONSE_TIMEOUT: int
SERVER_NAME: str SERVER_NAME: str
TOUCHUP: bool TLS_CERT_PASSWORD: str
TOUCHUP: Union[Default, bool]
USE_UVLOOP: Union[Default, bool] USE_UVLOOP: Union[Default, bool]
WEBSOCKET_MAX_SIZE: int WEBSOCKET_MAX_SIZE: int
WEBSOCKET_PING_INTERVAL: int WEBSOCKET_PING_INTERVAL: int
@@ -100,7 +109,6 @@ class Config(dict, metaclass=DescriptorMeta):
super().__init__({**DEFAULT_CONFIG, **defaults}) super().__init__({**DEFAULT_CONFIG, **defaults})
self._converters = [str, str_to_bool, float, int] self._converters = [str, str_to_bool, float, int]
self._LOGO = ""
if converters: if converters:
for converter in converters: for converter in converters:
@@ -157,17 +165,13 @@ class Config(dict, metaclass=DescriptorMeta):
"REQUEST_MAX_SIZE", "REQUEST_MAX_SIZE",
): ):
self._configure_header_size() self._configure_header_size()
elif attr == "LOGO":
self._LOGO = value
deprecation(
"Setting the config.LOGO is deprecated and will no longer "
"be supported starting in v22.6.",
22.6,
)
@property if attr == "LOCAL_CERT_CREATOR" and not isinstance(
def LOGO(self): self.LOCAL_CERT_CREATOR, LocalCertCreator
return self._LOGO ):
self.LOCAL_CERT_CREATOR = LocalCertCreator[
self.LOCAL_CERT_CREATOR.upper()
]
@property @property
def FALLBACK_ERROR_FORMAT(self) -> str: def FALLBACK_ERROR_FORMAT(self) -> str:

View File

@@ -24,5 +24,16 @@ class HTTPMethod(str, Enum):
DELETE = auto() DELETE = auto()
class LocalCertCreator(str, Enum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()
AUTO = auto()
TRUSTME = auto()
MKCERT = auto()
HTTP_METHODS = tuple(HTTPMethod.__members__.values()) HTTP_METHODS = tuple(HTTPMethod.__members__.values())
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
DEFAULT_LOCAL_TLS_KEY = "key.pem"
DEFAULT_LOCAL_TLS_CERT = "cert.pem"

View File

@@ -19,7 +19,7 @@ import typing as t
from functools import partial from functools import partial
from traceback import extract_tb from traceback import extract_tb
from sanic.exceptions import InvalidUsage, SanicException from sanic.exceptions import BadRequest, SanicException
from sanic.helpers import STATUS_CODES from sanic.helpers import STATUS_CODES
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
@@ -506,7 +506,7 @@ def exception_response(
# $ curl localhost:8000 -d '{"foo": "bar"}' # $ curl localhost:8000 -d '{"foo": "bar"}'
# And provide them with JSONRenderer # And provide them with JSONRenderer
renderer = JSONRenderer if request.json else base renderer = JSONRenderer if request.json else base
except InvalidUsage: except BadRequest:
renderer = base renderer = base
else: else:
renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) renderer = RENDERERS_BY_CONFIG.get(render_format, renderer)

View File

@@ -42,7 +42,7 @@ class NotFound(SanicException):
quiet = True quiet = True
class InvalidUsage(SanicException): class BadRequest(SanicException):
""" """
**Status**: 400 Bad Request **Status**: 400 Bad Request
""" """
@@ -51,11 +51,14 @@ class InvalidUsage(SanicException):
quiet = True quiet = True
class BadURL(InvalidUsage): InvalidUsage = BadRequest
class BadURL(BadRequest):
... ...
class MethodNotSupported(SanicException): class MethodNotAllowed(SanicException):
""" """
**Status**: 405 Method Not Allowed **Status**: 405 Method Not Allowed
""" """
@@ -68,6 +71,9 @@ class MethodNotSupported(SanicException):
self.headers = {"Allow": ", ".join(allowed_methods)} self.headers = {"Allow": ", ".join(allowed_methods)}
MethodNotSupported = MethodNotAllowed
class ServerError(SanicException): class ServerError(SanicException):
""" """
**Status**: 500 Internal Server Error **Status**: 500 Internal Server Error
@@ -129,19 +135,19 @@ class PayloadTooLarge(SanicException):
quiet = True quiet = True
class HeaderNotFound(InvalidUsage): class HeaderNotFound(BadRequest):
""" """
**Status**: 400 Bad Request **Status**: 400 Bad Request
""" """
class InvalidHeader(InvalidUsage): class InvalidHeader(BadRequest):
""" """
**Status**: 400 Bad Request **Status**: 400 Bad Request
""" """
class ContentRangeError(SanicException): class RangeNotSatisfiable(SanicException):
""" """
**Status**: 416 Range Not Satisfiable **Status**: 416 Range Not Satisfiable
""" """
@@ -154,7 +160,10 @@ class ContentRangeError(SanicException):
self.headers = {"Content-Range": f"bytes */{content_range.total}"} self.headers = {"Content-Range": f"bytes */{content_range.total}"}
class HeaderExpectationFailed(SanicException): ContentRangeError = RangeNotSatisfiable
class ExpectationFailed(SanicException):
""" """
**Status**: 417 Expectation Failed **Status**: 417 Expectation Failed
""" """
@@ -163,6 +172,9 @@ class HeaderExpectationFailed(SanicException):
quiet = True quiet = True
HeaderExpectationFailed = ExpectationFailed
class Forbidden(SanicException): class Forbidden(SanicException):
""" """
**Status**: 403 Forbidden **Status**: 403 Forbidden
@@ -172,7 +184,7 @@ class Forbidden(SanicException):
quiet = True quiet = True
class InvalidRangeType(ContentRangeError): class InvalidRangeType(RangeNotSatisfiable):
""" """
**Status**: 416 Range Not Satisfiable **Status**: 416 Range Not Satisfiable
""" """

View File

@@ -1,21 +1,13 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type
from sanic.config import Config from sanic.errorpages import BaseRenderer, TextRenderer, exception_response
from sanic.errorpages import (
DEFAULT_FORMAT,
BaseRenderer,
TextRenderer,
exception_response,
)
from sanic.exceptions import ( from sanic.exceptions import (
ContentRangeError,
HeaderNotFound, HeaderNotFound,
InvalidRangeType, InvalidRangeType,
SanicException, RangeNotSatisfiable,
) )
from sanic.helpers import Default, _default
from sanic.log import deprecation, error_logger from sanic.log import deprecation, error_logger
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
from sanic.response import text from sanic.response import text
@@ -36,91 +28,22 @@ class ErrorHandler:
def __init__( def __init__(
self, self,
fallback: Union[str, Default] = _default,
base: Type[BaseRenderer] = TextRenderer, base: Type[BaseRenderer] = TextRenderer,
): ):
self.cached_handlers: Dict[ self.cached_handlers: Dict[
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
] = {} ] = {}
self.debug = False self.debug = False
self._fallback = fallback
self.base = base self.base = base
if fallback is not _default: @classmethod
self._warn_fallback_deprecation() def finalize(cls, *args, **kwargs):
@property
def fallback(self): # no cov
# This is for backwards compat and can be removed in v22.6
if self._fallback is _default:
return DEFAULT_FORMAT
return self._fallback
@fallback.setter
def fallback(self, value: str): # no cov
self._warn_fallback_deprecation()
if not isinstance(value, str):
raise SanicException(
f"Cannot set error handler fallback to: value={value}"
)
self._fallback = value
@staticmethod
def _warn_fallback_deprecation():
deprecation( deprecation(
"Setting the ErrorHandler fallback value directly is " "ErrorHandler.finalize is deprecated and no longer needed. "
"deprecated and no longer supported. This feature will " "Please remove update your code to remove it. ",
"be removed in v22.6. Instead, use " 22.12,
"app.config.FALLBACK_ERROR_FORMAT.",
22.6,
) )
@classmethod
def _get_fallback_value(cls, error_handler: ErrorHandler, config: Config):
if error_handler._fallback is not _default:
if config._FALLBACK_ERROR_FORMAT is _default:
return error_handler.fallback
error_logger.warning(
"Conflicting error fallback values were found in the "
"error handler and in the app.config while handling an "
"exception. Using the value from app.config."
)
return config.FALLBACK_ERROR_FORMAT
@classmethod
def finalize(
cls,
error_handler: ErrorHandler,
config: Config,
fallback: Optional[str] = None,
):
if fallback:
deprecation(
"Setting the ErrorHandler fallback value via finalize() "
"is deprecated and no longer supported. This feature will "
"be removed in v22.6. Instead, use "
"app.config.FALLBACK_ERROR_FORMAT.",
22.6,
)
if not fallback:
fallback = config.FALLBACK_ERROR_FORMAT
if fallback != DEFAULT_FORMAT:
if error_handler._fallback is not _default:
error_logger.warning(
f"Setting the fallback value to {fallback}. This changes "
"the current non-default value "
f"'{error_handler._fallback}'."
)
error_handler._fallback = fallback
if not isinstance(error_handler, cls):
error_logger.warning(
f"Error handler is non-conforming: {type(error_handler)}"
)
def _full_lookup(self, exception, route_name: Optional[str] = None): def _full_lookup(self, exception, route_name: Optional[str] = None):
return self.lookup(exception, route_name) return self.lookup(exception, route_name)
@@ -237,7 +160,7 @@ class ErrorHandler:
:return: :return:
""" """
self.log(request, exception) self.log(request, exception)
fallback = ErrorHandler._get_fallback_value(self, request.app.config) fallback = request.app.config.FALLBACK_ERROR_FORMAT
return exception_response( return exception_response(
request, request,
exception, exception,
@@ -296,18 +219,18 @@ class ContentRangeHandler:
try: try:
self.start = int(start_b) if start_b else None self.start = int(start_b) if start_b else None
except ValueError: except ValueError:
raise ContentRangeError( raise RangeNotSatisfiable(
"'%s' is invalid for Content Range" % (start_b,), self "'%s' is invalid for Content Range" % (start_b,), self
) )
try: try:
self.end = int(end_b) if end_b else None self.end = int(end_b) if end_b else None
except ValueError: except ValueError:
raise ContentRangeError( raise RangeNotSatisfiable(
"'%s' is invalid for Content Range" % (end_b,), self "'%s' is invalid for Content Range" % (end_b,), self
) )
if self.end is None: if self.end is None:
if self.start is None: if self.start is None:
raise ContentRangeError( raise RangeNotSatisfiable(
"Invalid for Content Range parameters", self "Invalid for Content Range parameters", self
) )
else: else:
@@ -319,7 +242,7 @@ class ContentRangeHandler:
self.start = self.total - self.end self.start = self.total - self.end
self.end = self.total - 1 self.end = self.total - 1
if self.start >= self.end: if self.start >= self.end:
raise ContentRangeError( raise RangeNotSatisfiable(
"Invalid for Content Range parameters", self "Invalid for Content Range parameters", self
) )
self.size = self.end - self.start + 1 self.size = self.end - self.start + 1

6
sanic/http/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from .constants import Stage
from .http1 import Http
from .http3 import Http3
__all__ = ("Http", "Stage", "Http3")

29
sanic/http/constants.py Normal file
View File

@@ -0,0 +1,29 @@
from enum import Enum, IntEnum
class Stage(Enum):
"""
Enum for representing the stage of the request/response cycle
| ``IDLE`` Waiting for request
| ``REQUEST`` Request headers being received
| ``HANDLER`` Headers done, handler running
| ``RESPONSE`` Response headers sent, body in progress
| ``FAILED`` Unrecoverable state (error while sending response)
|
"""
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent, body in progress
FAILED = 100 # Unrecoverable state (error while sending response)
class HTTP(IntEnum):
VERSION_1 = 1
VERSION_3 = 3
def display(self) -> str:
value = 1.1 if self.value == 1 else self.value
return f"HTTP/{value}"

View File

@@ -3,71 +3,52 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic.request import Request from sanic.request import Request
from sanic.response import BaseHTTPResponse from sanic.response import BaseHTTPResponse
from asyncio import CancelledError, sleep from asyncio import CancelledError, sleep
from enum import Enum
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import ( from sanic.exceptions import (
HeaderExpectationFailed, BadRequest,
InvalidUsage, ExpectationFailed,
PayloadTooLarge, PayloadTooLarge,
ServerError, ServerError,
ServiceUnavailable, ServiceUnavailable,
) )
from sanic.headers import format_http1_response from sanic.headers import format_http1_response
from sanic.helpers import has_message_body from sanic.helpers import has_message_body
from sanic.http.constants import Stage
from sanic.http.stream import Stream
from sanic.log import access_logger, error_logger, logger from sanic.log import access_logger, error_logger, logger
from sanic.touchup import TouchUpMeta from sanic.touchup import TouchUpMeta
class Stage(Enum):
"""
Enum for representing the stage of the request/response cycle
| ``IDLE`` Waiting for request
| ``REQUEST`` Request headers being received
| ``HANDLER`` Headers done, handler running
| ``RESPONSE`` Response headers sent, body in progress
| ``FAILED`` Unrecoverable state (error while sending response)
|
"""
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent, body in progress
FAILED = 100 # Unrecoverable state (error while sending response)
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http(metaclass=TouchUpMeta): class Http(Stream, metaclass=TouchUpMeta):
""" """
Internal helper for managing the HTTP request/response cycle Internal helper for managing the HTTP/1.1 request/response cycle
:raises ServerError: :raises ServerError:
:raises PayloadTooLarge: :raises PayloadTooLarge:
:raises Exception: :raises Exception:
:raises InvalidUsage: :raises BadRequest:
:raises HeaderExpectationFailed: :raises ExpectationFailed:
:raises RuntimeError: :raises RuntimeError:
:raises ServerError: :raises ServerError:
:raises ServerError: :raises ServerError:
:raises InvalidUsage: :raises BadRequest:
:raises InvalidUsage: :raises BadRequest:
:raises InvalidUsage: :raises BadRequest:
:raises PayloadTooLarge: :raises PayloadTooLarge:
:raises RuntimeError: :raises RuntimeError:
""" """
HEADER_CEILING = 16_384 HEADER_CEILING = 16_384
HEADER_MAX_SIZE = 0 HEADER_MAX_SIZE = 0
__touchup__ = ( __touchup__ = (
"http1_request_header", "http1_request_header",
"http1_response_header", "http1_response_header",
@@ -248,7 +229,7 @@ class Http(metaclass=TouchUpMeta):
headers.append(h) headers.append(h)
except Exception: except Exception:
raise InvalidUsage("Bad Request") raise BadRequest("Bad Request")
headers_instance = Header(headers) headers_instance = Header(headers)
self.upgrade_websocket = ( self.upgrade_websocket = (
@@ -265,6 +246,7 @@ class Http(metaclass=TouchUpMeta):
transport=self.protocol.transport, transport=self.protocol.transport,
app=self.protocol.app, app=self.protocol.app,
) )
self.protocol.request_class._current.set(request)
await self.dispatch( await self.dispatch(
"http.lifecycle.request", "http.lifecycle.request",
inline=True, inline=True,
@@ -281,7 +263,7 @@ class Http(metaclass=TouchUpMeta):
if expect.lower() == "100-continue": if expect.lower() == "100-continue":
self.expecting_continue = True self.expecting_continue = True
else: else:
raise HeaderExpectationFailed(f"Unknown Expect: {expect}") raise ExpectationFailed(f"Unknown Expect: {expect}")
if headers.getone("transfer-encoding", None) == "chunked": if headers.getone("transfer-encoding", None) == "chunked":
self.request_body = "chunked" self.request_body = "chunked"
@@ -352,6 +334,12 @@ class Http(metaclass=TouchUpMeta):
self.response_func = self.head_response_ignored self.response_func = self.head_response_ignored
headers["connection"] = "keep-alive" if self.keep_alive else "close" headers["connection"] = "keep-alive" if self.keep_alive else "close"
# This header may be removed or modified by the AltSvcCheck Touchup
# service. At server start, we either remove this header from ever
# being assigned, or we change the value as required.
headers["alt-svc"] = ""
ret = format_http1_response(status, res.processed_headers) ret = format_http1_response(status, res.processed_headers)
if data: if data:
ret += data ret += data
@@ -510,7 +498,7 @@ class Http(metaclass=TouchUpMeta):
if len(buf) > 64: if len(buf) > 64:
self.keep_alive = False self.keep_alive = False
raise InvalidUsage("Bad chunked encoding") raise BadRequest("Bad chunked encoding")
await self._receive_more() await self._receive_more()
@@ -518,14 +506,14 @@ class Http(metaclass=TouchUpMeta):
size = int(buf[2:pos].split(b";", 1)[0].decode(), 16) size = int(buf[2:pos].split(b";", 1)[0].decode(), 16)
except Exception: except Exception:
self.keep_alive = False self.keep_alive = False
raise InvalidUsage("Bad chunked encoding") raise BadRequest("Bad chunked encoding")
if size <= 0: if size <= 0:
self.request_body = None self.request_body = None
if size < 0: if size < 0:
self.keep_alive = False self.keep_alive = False
raise InvalidUsage("Bad chunked encoding") raise BadRequest("Bad chunked encoding")
# Consume CRLF, chunk size 0 and the two CRLF that follow # Consume CRLF, chunk size 0 and the two CRLF that follow
pos += 4 pos += 4

406
sanic/http/http3.py Normal file
View File

@@ -0,0 +1,406 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from ssl import SSLContext
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
from sanic.compat import Header
from sanic.constants import LocalCertCreator
from sanic.exceptions import PayloadTooLarge, SanicException, ServerError
from sanic.helpers import has_message_body
from sanic.http.constants import Stage
from sanic.http.stream import Stream
from sanic.http.tls.context import CertSelector, CertSimple, SanicSSLContext
from sanic.log import Colors, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.models.server_types import ConnInfo
try:
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import (
DatagramReceived,
DataReceived,
H3Event,
HeadersReceived,
WebTransportStreamDataReceived,
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.tls import SessionTicket
HTTP3_AVAILABLE = True
except ModuleNotFoundError: # no cov
HTTP3_AVAILABLE = False
if TYPE_CHECKING:
from sanic import Sanic
from sanic.request import Request
from sanic.response import BaseHTTPResponse
from sanic.server.protocols.http_protocol import Http3Protocol
HttpConnection = Union[H0Connection, H3Connection]
class HTTP3Transport(TransportProtocol):
__slots__ = ("_protocol",)
def __init__(self, protocol: Http3Protocol):
self._protocol = protocol
def get_protocol(self) -> Http3Protocol:
return self._protocol
def get_extra_info(self, info: str, default: Any = None) -> Any:
if (
info in ("socket", "sockname", "peername")
and self._protocol._transport
):
return self._protocol._transport.get_extra_info(info, default)
elif info == "network_paths":
return self._protocol._quic._network_paths
elif info == "ssl_context":
return self._protocol.app.state.ssl
return default
class Receiver(ABC):
future: asyncio.Future
def __init__(self, transmit, protocol, request: Request) -> None:
self.transmit = transmit
self.protocol = protocol
self.request = request
@abstractmethod
async def run(self): # no cov
...
class HTTPReceiver(Receiver, Stream):
stage: Stage
request: Request
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.request_body = None
self.stage = Stage.IDLE
self.headers_sent = False
self.response: Optional[BaseHTTPResponse] = None
self.request_max_size = self.protocol.request_max_size
self.request_bytes = 0
async def run(self, exception: Optional[Exception] = None):
self.stage = Stage.HANDLER
self.head_only = self.request.method.upper() == "HEAD"
if exception:
logger.info( # no cov
f"{Colors.BLUE}[exception]: "
f"{Colors.RED}{exception}{Colors.END}",
exc_info=True,
extra={"verbosity": 1},
)
await self.error_response(exception)
else:
try:
logger.info( # no cov
f"{Colors.BLUE}[request]:{Colors.END} {self.request}",
extra={"verbosity": 1},
)
await self.protocol.request_handler(self.request)
except Exception as e: # no cov
# This should largely be handled within the request handler.
# But, just in case...
await self.run(e)
self.stage = Stage.IDLE
async def error_response(self, exception: Exception) -> None:
"""
Handle response when exception encountered
"""
# From request and handler states we can respond, otherwise be silent
app = self.protocol.app
await app.handle_exception(self.request, exception)
def _prepare_headers(
self, response: BaseHTTPResponse
) -> List[Tuple[bytes, bytes]]:
size = len(response.body) if response.body else 0
headers = response.headers
status = response.status
if not has_message_body(status) and (
size
or "content-length" in headers
or "transfer-encoding" in headers
):
headers.pop("content-length", None)
headers.pop("transfer-encoding", None)
logger.warning( # no cov
f"Message body set in response on {self.request.path}. "
f"A {status} response may only have headers, no body."
)
elif "content-length" not in headers:
if size:
headers["content-length"] = size
else:
headers["transfer-encoding"] = "chunked"
headers = [
(b":status", str(response.status).encode()),
*response.processed_headers,
]
return headers
def send_headers(self) -> None:
logger.debug( # no cov
f"{Colors.BLUE}[send]: {Colors.GREEN}HEADERS{Colors.END}",
extra={"verbosity": 2},
)
if not self.response:
raise RuntimeError("no response")
response = self.response
headers = self._prepare_headers(response)
self.protocol.connection.send_headers(
stream_id=self.request.stream_id,
headers=headers,
)
self.headers_sent = True
self.stage = Stage.RESPONSE
if self.response.body and not self.head_only:
self._send(self.response.body, False)
elif self.head_only:
self.future.cancel()
def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse:
logger.debug( # no cov
f"{Colors.BLUE}[respond]:{Colors.END} {response}",
extra={"verbosity": 2},
)
if self.stage is not Stage.HANDLER:
self.stage = Stage.FAILED
raise RuntimeError("Response already started")
# Disconnect any earlier but unused response object
if self.response is not None:
self.response.stream = None
self.response, response.stream = response, self
return response
def receive_body(self, data: bytes) -> None:
self.request_bytes += len(data)
if self.request_bytes > self.request_max_size:
raise PayloadTooLarge("Request body exceeds the size limit")
self.request.body += data
async def send(self, data: bytes, end_stream: bool) -> None:
logger.debug( # no cov
f"{Colors.BLUE}[send]: {Colors.GREEN}data={data.decode()} "
f"end_stream={end_stream}{Colors.END}",
extra={"verbosity": 2},
)
self._send(data, end_stream)
def _send(self, data: bytes, end_stream: bool) -> None:
if not self.headers_sent:
self.send_headers()
if self.stage is not Stage.RESPONSE:
raise ServerError(f"not ready to send: {self.stage}")
# Chunked
if (
self.response
and self.response.headers.get("transfer-encoding") == "chunked"
):
size = len(data)
if end_stream:
data = (
b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
if size
else b"0\r\n\r\n"
)
elif size:
data = b"%x\r\n%b\r\n" % (size, data)
logger.debug( # no cov
f"{Colors.BLUE}[transmitting]{Colors.END}",
extra={"verbosity": 2},
)
self.protocol.connection.send_data(
stream_id=self.request.stream_id,
data=data,
end_stream=end_stream,
)
self.transmit()
if end_stream:
self.stage = Stage.IDLE
class WebsocketReceiver(Receiver): # noqa
async def run(self):
...
class WebTransportReceiver(Receiver): # noqa
async def run(self):
...
class Http3:
"""
Internal helper for managing the HTTP/3 request/response cycle
"""
if HTTP3_AVAILABLE:
HANDLER_PROPERTY_MAPPING = {
DataReceived: "stream_id",
HeadersReceived: "stream_id",
DatagramReceived: "flow_id",
WebTransportStreamDataReceived: "session_id",
}
def __init__(
self,
protocol: Http3Protocol,
transmit: Callable[[], None],
) -> None:
self.protocol = protocol
self.transmit = transmit
self.receivers: Dict[int, Receiver] = {}
def http_event_received(self, event: H3Event) -> None:
logger.debug( # no cov
f"{Colors.BLUE}[http_event_received]: "
f"{Colors.YELLOW}{event}{Colors.END}",
extra={"verbosity": 2},
)
receiver, created_new = self.get_or_make_receiver(event)
receiver = cast(HTTPReceiver, receiver)
if isinstance(event, HeadersReceived) and created_new:
receiver.future = asyncio.ensure_future(receiver.run())
elif isinstance(event, DataReceived):
try:
receiver.receive_body(event.data)
except Exception as e:
receiver.future.cancel()
receiver.future = asyncio.ensure_future(receiver.run(e))
else:
... # Intentionally here to help out Touchup
logger.debug( # no cov
f"{Colors.RED}DOING NOTHING{Colors.END}",
extra={"verbosity": 2},
)
def get_or_make_receiver(self, event: H3Event) -> Tuple[Receiver, bool]:
if (
isinstance(event, HeadersReceived)
and event.stream_id not in self.receivers
):
request = self._make_request(event)
receiver = HTTPReceiver(self.transmit, self.protocol, request)
request.stream = receiver
self.receivers[event.stream_id] = receiver
return receiver, True
else:
ident = getattr(event, self.HANDLER_PROPERTY_MAPPING[type(event)])
return self.receivers[ident], False
def get_receiver_by_stream_id(self, stream_id: int) -> Receiver:
return self.receivers[stream_id]
def _make_request(self, event: HeadersReceived) -> Request:
headers = Header(((k.decode(), v.decode()) for k, v in event.headers))
method = headers[":method"]
path = headers[":path"]
scheme = headers.pop(":scheme", "")
authority = headers.pop(":authority", "")
if authority:
headers["host"] = authority
transport = HTTP3Transport(self.protocol)
request = self.protocol.request_class(
path.encode(),
headers,
"3",
method,
transport,
self.protocol.app,
b"",
)
request.conn_info = ConnInfo(transport)
request._stream_id = event.stream_id
request._scheme = scheme
return request
class SessionTicketStore:
"""
Simple in-memory store for session tickets.
"""
def __init__(self) -> None:
self.tickets: Dict[bytes, SessionTicket] = {}
def add(self, ticket: SessionTicket) -> None:
self.tickets[ticket.ticket] = ticket
def pop(self, label: bytes) -> Optional[SessionTicket]:
return self.tickets.pop(label, None)
def get_config(
app: Sanic, ssl: Union[SanicSSLContext, CertSelector, SSLContext]
):
# TODO:
# - proper selection needed if servince with multiple certs insted of
# just taking the first
if isinstance(ssl, CertSelector):
ssl = cast(SanicSSLContext, ssl.sanic_select[0])
if app.config.LOCAL_CERT_CREATOR is LocalCertCreator.TRUSTME:
raise SanicException(
"Sorry, you cannot currently use trustme as a local certificate "
"generator for an HTTP/3 server. This is not yet supported. You "
"should be able to use mkcert instead. For more information, see: "
"https://github.com/aiortc/aioquic/issues/295."
)
if not isinstance(ssl, CertSimple):
raise SanicException("SSLContext is not CertSimple")
config = QuicConfiguration(
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
is_client=False,
max_datagram_frame_size=65536,
)
password = app.config.TLS_CERT_PASSWORD or None
config.load_cert_chain(
ssl.sanic["cert"], ssl.sanic["key"], password=password
)
return config

27
sanic/http/stream.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Tuple, Union
from sanic.http.constants import Stage
if TYPE_CHECKING:
from sanic.response import BaseHTTPResponse
from sanic.server.protocols.http_protocol import HttpProtocol
class Stream:
stage: Stage
response: Optional[BaseHTTPResponse]
protocol: HttpProtocol
url: Optional[str]
request_body: Optional[bytes]
request_max_size: Union[int, float]
__touchup__: Tuple[str, ...] = tuple()
__slots__ = ()
def respond(
self, response: BaseHTTPResponse
) -> BaseHTTPResponse: # no cov
raise NotImplementedError("Not implemented")

View File

@@ -0,0 +1,5 @@
from .context import process_to_context
from .creators import get_ssl_context
__all__ = ("get_ssl_context", "process_to_context")

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import os import os
import ssl import ssl
from typing import Iterable, Optional, Union from typing import Any, Dict, Iterable, Optional, Union
from sanic.log import logger from sanic.log import logger
@@ -77,65 +79,6 @@ def load_cert_dir(p: str) -> ssl.SSLContext:
return CertSimple(certfile, keyfile) return CertSimple(certfile, keyfile)
class CertSimple(ssl.SSLContext):
"""A wrapper for creating SSLContext with a sanic attribute."""
def __new__(cls, cert, key, **kw):
# try common aliases, rename to cert/key
certfile = kw["cert"] = kw.pop("certificate", None) or cert
keyfile = kw["key"] = kw.pop("keyfile", None) or key
password = kw.pop("password", None)
if not certfile or not keyfile:
raise ValueError("SSL dict needs filenames for cert and key.")
subject = {}
if "names" not in kw:
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
kw["names"] = [
name
for t, name in cert["subjectAltName"]
if t in ["DNS", "IP Address"]
]
subject = {k: v for item in cert["subject"] for k, v in item}
self = create_context(certfile, keyfile, password)
self.__class__ = cls
self.sanic = {**subject, **kw}
return self
def __init__(self, cert, key, **kw):
pass # Do not call super().__init__ because it is already initialized
class CertSelector(ssl.SSLContext):
"""Automatically select SSL certificate based on the hostname that the
client is trying to access, via SSL SNI. Paths to certificate folders
with privkey.pem and fullchain.pem in them should be provided, and
will be matched in the order given whenever there is a new connection.
"""
def __new__(cls, ctxs):
return super().__new__(cls)
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
super().__init__()
self.sni_callback = selector_sni_callback # type: ignore
self.sanic_select = []
self.sanic_fallback = None
all_names = []
for i, ctx in enumerate(ctxs):
if not ctx:
continue
names = dict(getattr(ctx, "sanic", {})).get("names", [])
all_names += names
self.sanic_select.append(ctx)
if i == 0:
self.sanic_fallback = ctx
if not all_names:
raise ValueError(
"No certificates with SubjectAlternativeNames found."
)
logger.info(f"Certificate vhosts: {', '.join(all_names)}")
def find_cert(self: CertSelector, server_name: str): def find_cert(self: CertSelector, server_name: str):
"""Find the first certificate that matches the given SNI. """Find the first certificate that matches the given SNI.
@@ -194,3 +137,73 @@ def server_name_callback(
) -> None: ) -> None:
"""Store the received SNI as sslobj.sanic_server_name.""" """Store the received SNI as sslobj.sanic_server_name."""
sslobj.sanic_server_name = server_name # type: ignore sslobj.sanic_server_name = server_name # type: ignore
class SanicSSLContext(ssl.SSLContext):
sanic: Dict[str, os.PathLike]
@classmethod
def create_from_ssl_context(cls, context: ssl.SSLContext):
context.__class__ = cls
return context
class CertSimple(SanicSSLContext):
"""A wrapper for creating SSLContext with a sanic attribute."""
sanic: Dict[str, Any]
def __new__(cls, cert, key, **kw):
# try common aliases, rename to cert/key
certfile = kw["cert"] = kw.pop("certificate", None) or cert
keyfile = kw["key"] = kw.pop("keyfile", None) or key
password = kw.pop("password", None)
if not certfile or not keyfile:
raise ValueError("SSL dict needs filenames for cert and key.")
subject = {}
if "names" not in kw:
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
kw["names"] = [
name
for t, name in cert["subjectAltName"]
if t in ["DNS", "IP Address"]
]
subject = {k: v for item in cert["subject"] for k, v in item}
self = create_context(certfile, keyfile, password)
self.__class__ = cls
self.sanic = {**subject, **kw}
return self
def __init__(self, cert, key, **kw):
pass # Do not call super().__init__ because it is already initialized
class CertSelector(ssl.SSLContext):
"""Automatically select SSL certificate based on the hostname that the
client is trying to access, via SSL SNI. Paths to certificate folders
with privkey.pem and fullchain.pem in them should be provided, and
will be matched in the order given whenever there is a new connection.
"""
def __new__(cls, ctxs):
return super().__new__(cls)
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
super().__init__()
self.sni_callback = selector_sni_callback # type: ignore
self.sanic_select = []
self.sanic_fallback = None
all_names = []
for i, ctx in enumerate(ctxs):
if not ctx:
continue
names = dict(getattr(ctx, "sanic", {})).get("names", [])
all_names += names
self.sanic_select.append(ctx)
if i == 0:
self.sanic_fallback = ctx
if not all_names:
raise ValueError(
"No certificates with SubjectAlternativeNames found."
)
logger.info(f"Certificate vhosts: {', '.join(all_names)}")

278
sanic/http/tls/creators.py Normal file
View File

@@ -0,0 +1,278 @@
from __future__ import annotations
import ssl
import subprocess
import sys
from abc import ABC, abstractmethod
from contextlib import suppress
from pathlib import Path
from tempfile import mkdtemp
from types import ModuleType
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union, cast
from sanic.application.constants import Mode
from sanic.application.spinner import loading
from sanic.constants import (
DEFAULT_LOCAL_TLS_CERT,
DEFAULT_LOCAL_TLS_KEY,
LocalCertCreator,
)
from sanic.exceptions import SanicException
from sanic.helpers import Default
from sanic.http.tls.context import CertSimple, SanicSSLContext
try:
import trustme
TRUSTME_INSTALLED = True
except (ImportError, ModuleNotFoundError):
trustme = ModuleType("trustme")
TRUSTME_INSTALLED = False
if TYPE_CHECKING:
from sanic import Sanic
# Only allow secure ciphers, notably leaving out AES-CBC mode
# OpenSSL chooses ECDSA or RSA depending on the cert in use
CIPHERS_TLS12 = [
"ECDHE-ECDSA-CHACHA20-POLY1305",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-CHACHA20-POLY1305",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
]
def _make_path(maybe_path: Union[Path, str], tmpdir: Optional[Path]) -> Path:
if isinstance(maybe_path, Path):
return maybe_path
else:
path = Path(maybe_path)
if not path.exists():
if not tmpdir:
raise RuntimeError("Reached an unknown state. No tmpdir.")
return tmpdir / maybe_path
return path
def get_ssl_context(
app: Sanic, ssl: Optional[ssl.SSLContext]
) -> ssl.SSLContext:
if ssl:
return ssl
if app.state.mode is Mode.PRODUCTION:
raise SanicException(
"Cannot run Sanic as an HTTPS server in PRODUCTION mode "
"without passing a TLS certificate. If you are developing "
"locally, please enable DEVELOPMENT mode and Sanic will "
"generate a localhost TLS certificate. For more information "
"please see: ___."
)
creator = CertCreator.select(
app,
cast(LocalCertCreator, app.config.LOCAL_CERT_CREATOR),
app.config.LOCAL_TLS_KEY,
app.config.LOCAL_TLS_CERT,
)
context = creator.generate_cert(app.config.LOCALHOST)
return context
class CertCreator(ABC):
def __init__(self, app, key, cert) -> None:
self.app = app
self.key = key
self.cert = cert
self.tmpdir = None
if isinstance(self.key, Default) or isinstance(self.cert, Default):
self.tmpdir = Path(mkdtemp())
key = (
DEFAULT_LOCAL_TLS_KEY
if isinstance(self.key, Default)
else self.key
)
cert = (
DEFAULT_LOCAL_TLS_CERT
if isinstance(self.cert, Default)
else self.cert
)
self.key_path = _make_path(key, self.tmpdir)
self.cert_path = _make_path(cert, self.tmpdir)
@abstractmethod
def check_supported(self) -> None: # no cov
...
@abstractmethod
def generate_cert(self, localhost: str) -> ssl.SSLContext: # no cov
...
@classmethod
def select(
cls,
app: Sanic,
cert_creator: LocalCertCreator,
local_tls_key,
local_tls_cert,
) -> CertCreator:
creator: Optional[CertCreator] = None
cert_creator_options: Tuple[
Tuple[Type[CertCreator], LocalCertCreator], ...
] = (
(MkcertCreator, LocalCertCreator.MKCERT),
(TrustmeCreator, LocalCertCreator.TRUSTME),
)
for creator_class, local_creator in cert_creator_options:
creator = cls._try_select(
app,
creator,
creator_class,
local_creator,
cert_creator,
local_tls_key,
local_tls_cert,
)
if creator:
break
if not creator:
raise SanicException(
"Sanic could not find package to create a TLS certificate. "
"You must have either mkcert or trustme installed. See "
"_____ for more details."
)
return creator
@staticmethod
def _try_select(
app: Sanic,
creator: Optional[CertCreator],
creator_class: Type[CertCreator],
creator_requirement: LocalCertCreator,
creator_requested: LocalCertCreator,
local_tls_key,
local_tls_cert,
):
if creator or (
creator_requested is not LocalCertCreator.AUTO
and creator_requested is not creator_requirement
):
return creator
instance = creator_class(app, local_tls_key, local_tls_cert)
try:
instance.check_supported()
except SanicException:
if creator_requested is creator_requirement:
raise
else:
return None
return instance
class MkcertCreator(CertCreator):
def check_supported(self) -> None:
try:
subprocess.run( # nosec B603 B607
["mkcert", "-help"],
check=True,
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
)
except Exception as e:
raise SanicException(
"Sanic is attempting to use mkcert to generate local TLS "
"certificates since you did not supply a certificate, but "
"one is required. Sanic cannot proceed since mkcert does not "
"appear to be installed. Alternatively, you can use trustme. "
"Please install mkcert, trustme, or supply TLS certificates "
"to proceed. Installation instructions can be found here: "
"https://github.com/FiloSottile/mkcert.\n"
"Find out more information about your options here: "
"_____"
) from e
def generate_cert(self, localhost: str) -> ssl.SSLContext:
try:
if not self.cert_path.exists():
message = "Generating TLS certificate"
# TODO: Validate input for security
with loading(message):
cmd = [
"mkcert",
"-key-file",
str(self.key_path),
"-cert-file",
str(self.cert_path),
localhost,
]
resp = subprocess.run( # nosec B603
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
sys.stdout.write("\r" + " " * (len(message) + 4))
sys.stdout.flush()
sys.stdout.write(resp.stdout)
finally:
@self.app.main_process_stop
async def cleanup(*_): # no cov
if self.tmpdir:
with suppress(FileNotFoundError):
self.key_path.unlink()
self.cert_path.unlink()
self.tmpdir.rmdir()
return CertSimple(self.cert_path, self.key_path)
class TrustmeCreator(CertCreator):
def check_supported(self) -> None:
if not TRUSTME_INSTALLED:
raise SanicException(
"Sanic is attempting to use trustme to generate local TLS "
"certificates since you did not supply a certificate, but "
"one is required. Sanic cannot proceed since trustme does not "
"appear to be installed. Alternatively, you can use mkcert. "
"Please install mkcert, trustme, or supply TLS certificates "
"to proceed. Installation instructions can be found here: "
"https://github.com/python-trio/trustme.\n"
"Find out more information about your options here: "
"_____"
)
def generate_cert(self, localhost: str) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sanic_context = SanicSSLContext.create_from_ssl_context(context)
sanic_context.sanic = {
"cert": self.cert_path.absolute(),
"key": self.key_path.absolute(),
}
ca = trustme.CA()
server_cert = ca.issue_cert(localhost)
server_cert.configure_cert(sanic_context)
ca.configure_trust(context)
ca.cert_pem.write_to_path(str(self.cert_path.absolute()))
server_cert.private_key_and_cert_chain_pem.write_to_path(
str(self.key_path.absolute())
)
return context

View File

@@ -5,6 +5,8 @@ from enum import Enum
from typing import Any, Dict from typing import Any, Dict
from warnings import warn from warnings import warn
from sanic.compat import is_atty
LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( # no cov LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( # no cov
version=1, version=1,
@@ -55,35 +57,53 @@ LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( # no cov
}, },
}, },
) )
"""
Defult logging configuration
"""
class Colors(str, Enum): # no cov class Colors(str, Enum): # no cov
END = "\033[0m" END = "\033[0m"
BLUE = "\033[01;34m" BLUE = "\033[01;34m"
GREEN = "\033[01;32m" GREEN = "\033[01;32m"
YELLOW = "\033[01;33m" PURPLE = "\033[01;35m"
RED = "\033[01;31m" RED = "\033[01;31m"
SANIC = "\033[38;2;255;13;104m"
YELLOW = "\033[01;33m"
class VerbosityFilter(logging.Filter):
verbosity: int = 0
def filter(self, record: logging.LogRecord) -> bool:
verbosity = getattr(record, "verbosity", 0)
return verbosity <= self.verbosity
_verbosity_filter = VerbosityFilter()
logger = logging.getLogger("sanic.root") # no cov logger = logging.getLogger("sanic.root") # no cov
""" """
General Sanic logger General Sanic logger
""" """
logger.addFilter(_verbosity_filter)
error_logger = logging.getLogger("sanic.error") # no cov error_logger = logging.getLogger("sanic.error") # no cov
""" """
Logger used by Sanic for error logging Logger used by Sanic for error logging
""" """
error_logger.addFilter(_verbosity_filter)
access_logger = logging.getLogger("sanic.access") # no cov access_logger = logging.getLogger("sanic.access") # no cov
""" """
Logger used by Sanic for access logging Logger used by Sanic for access logging
""" """
access_logger.addFilter(_verbosity_filter)
def deprecation(message: str, version: float): # no cov def deprecation(message: str, version: float): # no cov
version_info = f"[DEPRECATION v{version}] " version_info = f"[DEPRECATION v{version}] "
if sys.stdout.isatty(): if is_atty():
version_info = f"{Colors.RED}{version_info}" version_info = f"{Colors.RED}{version_info}"
message = f"{Colors.YELLOW}{message}{Colors.END}" message = f"{Colors.YELLOW}{message}{Colors.END}"
warn(version_info + message, DeprecationWarning) warn(version_info + message, DeprecationWarning)

View File

@@ -3,7 +3,7 @@ from functools import partial
from typing import Callable, List, Optional, Union, overload from typing import Callable, List, Optional, Union, overload
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.exceptions import InvalidUsage from sanic.exceptions import BadRequest
from sanic.models.futures import FutureListener from sanic.models.futures import FutureListener
from sanic.models.handler_types import ListenerType, Sanic from sanic.models.handler_types import ListenerType, Sanic
@@ -86,7 +86,7 @@ class ListenerMixin(metaclass=SanicMeta):
if callable(listener_or_event): if callable(listener_or_event):
if event_or_none is None: if event_or_none is None:
raise InvalidUsage( raise BadRequest(
"Invalid event registration: Missing event name." "Invalid event registration: Missing event name."
) )
return register_listener(listener_or_event, event_or_none) return register_listener(listener_or_event, event_or_none)

View File

@@ -4,27 +4,31 @@ from functools import partial, wraps
from inspect import getsource, signature from inspect import getsource, signature
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from pathlib import PurePath from pathlib import Path, PurePath
from re import sub
from textwrap import dedent from textwrap import dedent
from time import gmtime, strftime from time import gmtime, strftime
from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from urllib.parse import unquote from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.compat import stat_async from sanic.compat import stat_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
from sanic.errorpages import RESPONSE_MAPPING from sanic.errorpages import RESPONSE_MAPPING
from sanic.exceptions import ( from sanic.exceptions import FileNotFound, HeaderNotFound, RangeNotSatisfiable
ContentRangeError,
FileNotFound,
HeaderNotFound,
InvalidUsage,
)
from sanic.handlers import ContentRangeHandler from sanic.handlers import ContentRangeHandler
from sanic.log import deprecation, error_logger from sanic.log import error_logger
from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.futures import FutureRoute, FutureStatic
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
from sanic.response import HTTPResponse, file, file_stream from sanic.response import HTTPResponse, file, file_stream
@@ -283,7 +287,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **GET** *HTTP* method Add an API URL under the **GET** *HTTP* method
@@ -299,17 +303,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"GET"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"GET"}),
version=version, host=host,
name=name, strict_slashes=strict_slashes,
ignore_body=ignore_body, version=version,
version_prefix=version_prefix, name=name,
error_format=error_format, ignore_body=ignore_body,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def post( def post(
@@ -323,7 +330,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **POST** *HTTP* method Add an API URL under the **POST** *HTTP* method
@@ -339,17 +346,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"POST"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"POST"}),
stream=stream, host=host,
version=version, strict_slashes=strict_slashes,
name=name, stream=stream,
version_prefix=version_prefix, version=version,
error_format=error_format, name=name,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def put( def put(
@@ -363,7 +373,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **PUT** *HTTP* method Add an API URL under the **PUT** *HTTP* method
@@ -379,17 +389,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"PUT"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"PUT"}),
stream=stream, host=host,
version=version, strict_slashes=strict_slashes,
name=name, stream=stream,
version_prefix=version_prefix, version=version,
error_format=error_format, name=name,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def head( def head(
@@ -403,7 +416,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **HEAD** *HTTP* method Add an API URL under the **HEAD** *HTTP* method
@@ -427,17 +440,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"HEAD"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"HEAD"}),
version=version, host=host,
name=name, strict_slashes=strict_slashes,
ignore_body=ignore_body, version=version,
version_prefix=version_prefix, name=name,
error_format=error_format, ignore_body=ignore_body,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def options( def options(
@@ -451,7 +467,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **OPTIONS** *HTTP* method Add an API URL under the **OPTIONS** *HTTP* method
@@ -475,17 +491,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"OPTIONS"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"OPTIONS"}),
version=version, host=host,
name=name, strict_slashes=strict_slashes,
ignore_body=ignore_body, version=version,
version_prefix=version_prefix, name=name,
error_format=error_format, ignore_body=ignore_body,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def patch( def patch(
@@ -499,7 +518,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **PATCH** *HTTP* method Add an API URL under the **PATCH** *HTTP* method
@@ -525,17 +544,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"PATCH"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"PATCH"}),
stream=stream, host=host,
version=version, strict_slashes=strict_slashes,
name=name, stream=stream,
version_prefix=version_prefix, version=version,
error_format=error_format, name=name,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def delete( def delete(
@@ -549,7 +571,7 @@ class RouteMixin(metaclass=SanicMeta):
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None, error_format: Optional[str] = None,
**ctx_kwargs, **ctx_kwargs,
) -> RouteWrapper: ) -> RouteHandler:
""" """
Add an API URL under the **DELETE** *HTTP* method Add an API URL under the **DELETE** *HTTP* method
@@ -565,17 +587,20 @@ class RouteMixin(metaclass=SanicMeta):
will be appended to the route context (``route.ctx``) will be appended to the route context (``route.ctx``)
:return: Object decorated with :func:`route` method :return: Object decorated with :func:`route` method
""" """
return self.route( return cast(
uri, RouteHandler,
methods=frozenset({"DELETE"}), self.route(
host=host, uri,
strict_slashes=strict_slashes, methods=frozenset({"DELETE"}),
version=version, host=host,
name=name, strict_slashes=strict_slashes,
ignore_body=ignore_body, version=version,
version_prefix=version_prefix, name=name,
error_format=error_format, ignore_body=ignore_body,
**ctx_kwargs, version_prefix=version_prefix,
error_format=error_format,
**ctx_kwargs,
),
) )
def websocket( def websocket(
@@ -775,32 +800,40 @@ class RouteMixin(metaclass=SanicMeta):
content_type=None, content_type=None,
__file_uri__=None, __file_uri__=None,
): ):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if __file_uri__ and "../" in __file_uri__:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided # Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python file_path_raw = Path(unquote(file_or_directory))
# from herping a derp and treating the uri as an absolute path root_path = file_path = file_path_raw.resolve()
root_path = file_path = file_or_directory not_found = FileNotFound(
if __file_uri__: "File not found",
file_path = path.join( path=file_or_directory,
file_or_directory, sub("^[/]*", "", __file_uri__) relative_url=__file_uri__,
) )
# URL decode the path sent by the browser otherwise we won't be able to if __file_uri__:
# match filenames which got encoded (filenames with spaces etc) # Strip all / that in the beginning of the URL to help prevent
file_path = path.abspath(unquote(file_path)) # python from herping a derp and treating the uri as an
if not file_path.startswith(path.abspath(unquote(root_path))): # absolute path
error_logger.exception( unquoted_file_uri = unquote(__file_uri__).lstrip("/")
f"File not found: path={file_or_directory}, " file_path_raw = Path(file_or_directory, unquoted_file_uri)
f"relative_url={__file_uri__}" file_path = file_path_raw.resolve()
) if (
raise FileNotFound( file_path < root_path and not file_path_raw.is_symlink()
"File not found", ) or ".." in file_path_raw.parts:
path=file_or_directory, error_logger.exception(
relative_url=__file_uri__, f"File not found: path={file_or_directory}, "
) f"relative_url={__file_uri__}"
)
raise not_found
try:
file_path.relative_to(root_path)
except ValueError:
if not file_path_raw.is_symlink():
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={__file_uri__}"
)
raise not_found
try: try:
headers = {} headers = {}
# Check if the client has been sent this file before # Check if the client has been sent this file before
@@ -865,14 +898,10 @@ class RouteMixin(metaclass=SanicMeta):
file_path, headers=headers, _range=_range file_path, headers=headers, _range=_range
) )
return await file(file_path, headers=headers, _range=_range) return await file(file_path, headers=headers, _range=_range)
except ContentRangeError: except RangeNotSatisfiable:
raise raise
except FileNotFoundError: except FileNotFoundError:
raise FileNotFound( raise not_found
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
except Exception: except Exception:
error_logger.exception( error_logger.exception(
f"Exception in static request handler: " f"Exception in static request handler: "
@@ -994,17 +1023,6 @@ class RouteMixin(metaclass=SanicMeta):
nonlocal types nonlocal types
with suppress(AttributeError): with suppress(AttributeError):
if node.value.func.id == "stream": # type: ignore
deprecation(
"The sanic.response.stream method has been "
"deprecated and will be removed in v22.6. Please "
"upgrade your application to use the new style "
"streaming pattern. See "
"https://sanicframework.org/en/guide/advanced/"
"streaming.html#response-streaming for more "
"information.",
22.6,
)
checks = [node.value.func.id] # type: ignore checks = [node.value.func.id] # type: ignore
if node.value.keywords: # type: ignore if node.value.keywords: # type: ignore
checks += [ checks += [
@@ -1035,7 +1053,7 @@ class RouteMixin(metaclass=SanicMeta):
raise AttributeError( raise AttributeError(
"Cannot use restricted route context: " "Cannot use restricted route context: "
f"{restricted_arguments}. This limitation is only in place " f"{restricted_arguments}. This limitation is only in place "
"until v22.3 when the restricted names will no longer be in" "until v22.9 when the restricted names will no longer be in"
"conflict. See https://github.com/sanic-org/sanic/issues/2303 " "conflict. See https://github.com/sanic-org/sanic/issues/2303 "
"for more information." "for more information."
) )

View File

@@ -19,16 +19,29 @@ from importlib import import_module
from pathlib import Path from pathlib import Path
from socket import socket from socket import socket
from ssl import SSLContext from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
from sanic import reloader_helpers from sanic import reloader_helpers
from sanic.application.logo import get_logo from sanic.application.logo import get_logo
from sanic.application.motd import MOTD from sanic.application.motd import MOTD
from sanic.application.state import ApplicationServerInfo, Mode, ServerStage from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.compat import OS_IS_WINDOWS from sanic.compat import OS_IS_WINDOWS, is_atty
from sanic.helpers import _default from sanic.helpers import _default
from sanic.log import Colors, error_logger, logger from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context, process_to_context
from sanic.log import Colors, deprecation, error_logger, logger
from sanic.models.handler_types import ListenerType from sanic.models.handler_types import ListenerType
from sanic.server import Signal as ServerSignal from sanic.server import Signal as ServerSignal
from sanic.server import try_use_uvloop from sanic.server import try_use_uvloop
@@ -37,19 +50,26 @@ from sanic.server.events import trigger_events
from sanic.server.protocols.http_protocol import HttpProtocol from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.protocols.websocket_protocol import WebSocketProtocol
from sanic.server.runners import serve, serve_multiple, serve_single from sanic.server.runners import serve, serve_multiple, serve_single
from sanic.tls import process_to_context
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic
from sanic.application.state import ApplicationState from sanic.application.state import ApplicationState
from sanic.config import Config from sanic.config import Config
SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext") SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext")
if sys.version_info < (3, 8):
HTTPVersion = Union[HTTP, int]
else:
from typing import Literal
HTTPVersion = Union[HTTP, Literal[1], Literal[3]]
class RunnerMixin(metaclass=SanicMeta): class RunnerMixin(metaclass=SanicMeta):
_app_registry: Dict[str, Sanic] _app_registry: Dict[str, Sanic]
asgi: bool
config: Config config: Config
listeners: Dict[str, List[ListenerType[Any]]] listeners: Dict[str, List[ListenerType[Any]]]
state: ApplicationState state: ApplicationState
@@ -67,6 +87,7 @@ class RunnerMixin(metaclass=SanicMeta):
dev: bool = False, dev: bool = False,
debug: bool = False, debug: bool = False,
auto_reload: Optional[bool] = None, auto_reload: Optional[bool] = None,
version: HTTPVersion = HTTP.VERSION_1,
ssl: Union[None, SSLContext, dict, str, list, tuple] = None, ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
sock: Optional[socket] = None, sock: Optional[socket] = None,
workers: int = 1, workers: int = 1,
@@ -82,6 +103,7 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False, fast: bool = False,
verbosity: int = 0, verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None, motd_display: Optional[Dict[str, str]] = None,
auto_tls: bool = False,
) -> None: ) -> None:
""" """
Run the HTTP Server and listen until keyboard interrupt or term Run the HTTP Server and listen until keyboard interrupt or term
@@ -125,6 +147,7 @@ class RunnerMixin(metaclass=SanicMeta):
dev=dev, dev=dev,
debug=debug, debug=debug,
auto_reload=auto_reload, auto_reload=auto_reload,
version=version,
ssl=ssl, ssl=ssl,
sock=sock, sock=sock,
workers=workers, workers=workers,
@@ -140,6 +163,7 @@ class RunnerMixin(metaclass=SanicMeta):
fast=fast, fast=fast,
verbosity=verbosity, verbosity=verbosity,
motd_display=motd_display, motd_display=motd_display,
auto_tls=auto_tls,
) )
self.__class__.serve(primary=self) # type: ignore self.__class__.serve(primary=self) # type: ignore
@@ -152,6 +176,7 @@ class RunnerMixin(metaclass=SanicMeta):
dev: bool = False, dev: bool = False,
debug: bool = False, debug: bool = False,
auto_reload: Optional[bool] = None, auto_reload: Optional[bool] = None,
version: HTTPVersion = HTTP.VERSION_1,
ssl: Union[None, SSLContext, dict, str, list, tuple] = None, ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
sock: Optional[socket] = None, sock: Optional[socket] = None,
workers: int = 1, workers: int = 1,
@@ -167,7 +192,15 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False, fast: bool = False,
verbosity: int = 0, verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None, motd_display: Optional[Dict[str, str]] = None,
auto_tls: bool = False,
) -> None: ) -> None:
if version == 3 and self.state.server_info:
raise RuntimeError(
"Serving HTTP/3 instances as a secondary server is "
"not supported. There can only be a single HTTP/3 worker "
"and it must be the first instance prepared."
)
if dev: if dev:
debug = True debug = True
auto_reload = True auto_reload = True
@@ -209,7 +242,7 @@ class RunnerMixin(metaclass=SanicMeta):
return return
if sock is None: if sock is None:
host, port = host or "127.0.0.1", port or 8000 host, port = self.get_address(host, port, version, auto_tls)
if protocol is None: if protocol is None:
protocol = ( protocol = (
@@ -237,6 +270,7 @@ class RunnerMixin(metaclass=SanicMeta):
host=host, host=host,
port=port, port=port,
debug=debug, debug=debug,
version=version,
ssl=ssl, ssl=ssl,
sock=sock, sock=sock,
unix=unix, unix=unix,
@@ -244,6 +278,7 @@ class RunnerMixin(metaclass=SanicMeta):
protocol=protocol, protocol=protocol,
backlog=backlog, backlog=backlog,
register_sys_signals=register_sys_signals, register_sys_signals=register_sys_signals,
auto_tls=auto_tls,
) )
self.state.server_info.append( self.state.server_info.append(
ApplicationServerInfo(settings=server_settings) ApplicationServerInfo(settings=server_settings)
@@ -313,7 +348,7 @@ class RunnerMixin(metaclass=SanicMeta):
""" """
if sock is None: if sock is None:
host, port = host or "127.0.0.1", port or 8000 host, port = host, port = self.get_address(host, port)
if protocol is None: if protocol is None:
protocol = ( protocol = (
@@ -378,6 +413,7 @@ class RunnerMixin(metaclass=SanicMeta):
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[int] = None, port: Optional[int] = None,
debug: bool = False, debug: bool = False,
version: HTTPVersion = HTTP.VERSION_1,
ssl: Union[None, SSLContext, dict, str, list, tuple] = None, ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
sock: Optional[socket] = None, sock: Optional[socket] = None,
unix: Optional[str] = None, unix: Optional[str] = None,
@@ -387,6 +423,7 @@ class RunnerMixin(metaclass=SanicMeta):
backlog: int = 100, backlog: int = 100,
register_sys_signals: bool = True, register_sys_signals: bool = True,
run_async: bool = False, run_async: bool = False,
auto_tls: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Helper function used by `run` and `create_server`.""" """Helper function used by `run` and `create_server`."""
if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0:
@@ -396,11 +433,18 @@ class RunnerMixin(metaclass=SanicMeta):
"#proxy-configuration" "#proxy-configuration"
) )
ssl = process_to_context(ssl)
if not self.state.is_debug: if not self.state.is_debug:
self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION
if isinstance(version, int):
version = HTTP(version)
ssl = process_to_context(ssl)
if version is HTTP.VERSION_3 or auto_tls:
if TYPE_CHECKING:
self = cast(Sanic, self)
ssl = get_ssl_context(self, ssl)
self.state.host = host or "" self.state.host = host or ""
self.state.port = port or 0 self.state.port = port or 0
self.state.workers = workers self.state.workers = workers
@@ -412,6 +456,7 @@ class RunnerMixin(metaclass=SanicMeta):
"protocol": protocol, "protocol": protocol,
"host": host, "host": host,
"port": port, "port": port,
"version": version,
"sock": sock, "sock": sock,
"unix": unix, "unix": unix,
"ssl": ssl, "ssl": ssl,
@@ -422,9 +467,9 @@ class RunnerMixin(metaclass=SanicMeta):
"backlog": backlog, "backlog": backlog,
} }
self.motd(self.serve_location) self.motd(server_settings=server_settings)
if sys.stdout.isatty() and not self.state.is_debug: if is_atty() and not self.state.is_debug:
error_logger.warning( error_logger.warning(
f"{Colors.YELLOW}Sanic is running in PRODUCTION mode. " f"{Colors.YELLOW}Sanic is running in PRODUCTION mode. "
"Consider using '--debug' or '--dev' while actively " "Consider using '--debug' or '--dev' while actively "
@@ -448,7 +493,19 @@ class RunnerMixin(metaclass=SanicMeta):
return server_settings return server_settings
def motd(self, serve_location): def motd(
self,
serve_location: str = "",
server_settings: Optional[Dict[str, Any]] = None,
):
if serve_location:
deprecation(
"Specifying a serve_location in the MOTD is deprecated and "
"will be removed.",
22.9,
)
else:
serve_location = self.get_server_location(server_settings)
if self.config.MOTD: if self.config.MOTD:
mode = [f"{self.state.mode},"] mode = [f"{self.state.mode},"]
if self.state.fast: if self.state.fast:
@@ -461,9 +518,19 @@ class RunnerMixin(metaclass=SanicMeta):
else: else:
mode.append(f"w/ {self.state.workers} workers") mode.append(f"w/ {self.state.workers} workers")
if server_settings:
server = ", ".join(
(
self.state.server,
server_settings["version"].display(), # type: ignore
)
)
else:
server = "ASGI" if self.asgi else "unknown" # type: ignore
display = { display = {
"mode": " ".join(mode), "mode": " ".join(mode),
"server": self.state.server, "server": server,
"python": platform.python_version(), "python": platform.python_version(),
"platform": platform.platform(), "platform": platform.platform(),
} }
@@ -487,7 +554,9 @@ class RunnerMixin(metaclass=SanicMeta):
module_name = package_name.replace("-", "_") module_name = package_name.replace("-", "_")
try: try:
module = import_module(module_name) module = import_module(module_name)
packages.append(f"{package_name}=={module.__version__}") packages.append(
f"{package_name}=={module.__version__}" # type: ignore
)
except ImportError: except ImportError:
... ...
@@ -497,35 +566,60 @@ class RunnerMixin(metaclass=SanicMeta):
if self.config.MOTD_DISPLAY: if self.config.MOTD_DISPLAY:
extra.update(self.config.MOTD_DISPLAY) extra.update(self.config.MOTD_DISPLAY)
logo = ( logo = get_logo(coffee=self.state.coffee)
get_logo(coffee=self.state.coffee)
if self.config.LOGO == "" or self.config.LOGO is True
else self.config.LOGO
)
MOTD.output(logo, serve_location, display, extra) MOTD.output(logo, serve_location, display, extra)
@property @property
def serve_location(self) -> str: def serve_location(self) -> str:
try:
server_settings = self.state.server_info[0].settings
return self.get_server_location(server_settings)
except IndexError:
location = "ASGI" if self.asgi else "unknown" # type: ignore
return f"http://<{location}>"
@staticmethod
def get_server_location(
server_settings: Optional[Dict[str, Any]] = None
) -> str:
serve_location = "" serve_location = ""
proto = "http" proto = "http"
if self.state.ssl is not None: if not server_settings:
return serve_location
if server_settings["ssl"] is not None:
proto = "https" proto = "https"
if self.state.unix: if server_settings["unix"]:
serve_location = f"{self.state.unix} {proto}://..." serve_location = f'{server_settings["unix"]} {proto}://...'
elif self.state.sock: elif server_settings["sock"]:
serve_location = f"{self.state.sock.getsockname()} {proto}://..." serve_location = (
elif self.state.host and self.state.port: f'{server_settings["sock"].getsockname()} {proto}://...'
)
elif server_settings["host"] and server_settings["port"]:
# colon(:) is legal for a host only in an ipv6 address # colon(:) is legal for a host only in an ipv6 address
display_host = ( display_host = (
f"[{self.state.host}]" f'[{server_settings["host"]}]'
if ":" in self.state.host if ":" in server_settings["host"]
else self.state.host else server_settings["host"]
)
serve_location = (
f'{proto}://{display_host}:{server_settings["port"]}'
) )
serve_location = f"{proto}://{display_host}:{self.state.port}"
return serve_location return serve_location
@staticmethod
def get_address(
host: Optional[str],
port: Optional[int],
version: HTTPVersion = HTTP.VERSION_1,
auto_tls: bool = False,
) -> Tuple[str, int]:
host = host or "127.0.0.1"
port = port or (8443 if (version == 3 or auto_tls) else 8000)
return host, port
@classmethod @classmethod
def should_auto_reload(cls) -> bool: def should_auto_reload(cls) -> bool:
return any(app.state.auto_reload for app in cls._app_registry.values()) return any(app.state.auto_reload for app in cls._app_registry.values())
@@ -615,7 +709,7 @@ class RunnerMixin(metaclass=SanicMeta):
f"{app.state.workers} worker(s), which will be ignored " f"{app.state.workers} worker(s), which will be ignored "
"in favor of the primary application." "in favor of the primary application."
) )
if sys.stdout.isatty(): if is_atty():
message = "".join( message = "".join(
[ [
Colors.YELLOW, Colors.YELLOW,
@@ -656,7 +750,7 @@ class RunnerMixin(metaclass=SanicMeta):
"The encountered error was: " "The encountered error was: "
) )
second_message = str(e) second_message = str(e)
if sys.stdout.isatty(): if is_atty():
message_parts = [ message_parts = [
Colors.YELLOW, Colors.YELLOW,
first_message, first_message,

View File

@@ -3,7 +3,8 @@ import sys
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
from sanic.exceptions import InvalidUsage from sanic.exceptions import BadRequest
from sanic.models.protocol_types import TransportProtocol
from sanic.server.websockets.connection import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
@@ -56,7 +57,7 @@ class MockProtocol: # no cov
await self._not_paused.wait() await self._not_paused.wait()
class MockTransport: # no cov class MockTransport(TransportProtocol): # no cov
_protocol: Optional[MockProtocol] _protocol: Optional[MockProtocol]
def __init__( def __init__(
@@ -68,23 +69,25 @@ class MockTransport: # no cov
self._protocol = None self._protocol = None
self.loop = None self.loop = None
def get_protocol(self) -> MockProtocol: def get_protocol(self) -> MockProtocol: # type: ignore
if not self._protocol: if not self._protocol:
self._protocol = MockProtocol(self, self.loop) self._protocol = MockProtocol(self, self.loop)
return self._protocol return self._protocol
def get_extra_info(self, info: str) -> Union[str, bool, None]: def get_extra_info(
self, info: str, default=None
) -> Optional[Union[str, bool]]:
if info == "peername": if info == "peername":
return self.scope.get("client") return self.scope.get("client")
elif info == "sslcontext": elif info == "sslcontext":
return self.scope.get("scheme") in ["https", "wss"] return self.scope.get("scheme") in ["https", "wss"]
return None return default
def get_websocket_connection(self) -> WebSocketConnection: def get_websocket_connection(self) -> WebSocketConnection:
try: try:
return self._websocket_connection return self._websocket_connection
except AttributeError: except AttributeError:
raise InvalidUsage("Improper websocket connection.") raise BadRequest("Improper websocket connection.")
def create_websocket_connection( def create_websocket_connection(
self, send: ASGISend, receive: ASGIReceive self, send: ASGISend, receive: ASGIReceive

View File

@@ -1,28 +1,22 @@
from __future__ import annotations
import sys import sys
from typing import Any, AnyStr, TypeVar, Union from asyncio import BaseTransport
from typing import TYPE_CHECKING, Any, AnyStr
if TYPE_CHECKING:
from sanic.models.asgi import ASGIScope
if sys.version_info < (3, 8): if sys.version_info < (3, 8):
from asyncio import BaseTransport
# from sanic.models.asgi import MockTransport
MockTransport = TypeVar("MockTransport")
TransportProtocol = Union[MockTransport, BaseTransport]
Range = Any Range = Any
HTMLProtocol = Any HTMLProtocol = Any
else: else:
# Protocol is a 3.8+ feature # Protocol is a 3.8+ feature
from typing import Protocol from typing import Protocol
class TransportProtocol(Protocol):
def get_protocol(self):
...
def get_extra_info(self, info: str) -> Union[str, bool, None]:
...
class HTMLProtocol(Protocol): class HTMLProtocol(Protocol):
def __html__(self) -> AnyStr: def __html__(self) -> AnyStr:
... ...
@@ -42,3 +36,8 @@ else:
def total(self) -> int: def total(self) -> int:
... ...
class TransportProtocol(BaseTransport):
scope: ASGIScope
__slots__ = ()

View File

@@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
from ssl import SSLObject from ssl import SSLContext, SSLObject
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
@@ -28,6 +28,7 @@ class ConnInfo:
"sockname", "sockname",
"ssl", "ssl",
"cert", "cert",
"network_paths",
) )
def __init__(self, transport: TransportProtocol, unix=None): def __init__(self, transport: TransportProtocol, unix=None):
@@ -40,17 +41,22 @@ class ConnInfo:
self.ssl = False self.ssl = False
self.server_name = "" self.server_name = ""
self.cert: Dict[str, Any] = {} self.cert: Dict[str, Any] = {}
self.network_paths: List[Any] = []
sslobj: Optional[SSLObject] = transport.get_extra_info( sslobj: Optional[SSLObject] = transport.get_extra_info(
"ssl_object" "ssl_object"
) # type: ignore ) # type: ignore
sslctx: Optional[SSLContext] = transport.get_extra_info(
"ssl_context"
) # type: ignore
if sslobj: if sslobj:
self.ssl = True self.ssl = True
self.server_name = getattr(sslobj, "sanic_server_name", None) or "" self.server_name = getattr(sslobj, "sanic_server_name", None) or ""
self.cert = dict(getattr(sslobj.context, "sanic", {})) self.cert = dict(getattr(sslobj.context, "sanic", {}))
if sslctx and not self.cert:
self.cert = dict(getattr(sslctx, "sanic", {}))
if isinstance(addr, str): # UNIX socket if isinstance(addr, str): # UNIX socket
self.server = unix or addr self.server = unix or addr
return return
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple): if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
@@ -59,6 +65,9 @@ class ConnInfo:
if addr[1] != (443 if self.ssl else 80): if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}" self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername") self.peername = addr = transport.get_extra_info("peername")
self.network_paths = transport.get_extra_info( # type: ignore
"network_paths"
)
if isinstance(addr, tuple): if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar
from inspect import isawaitable
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@@ -12,12 +14,15 @@ from typing import (
Union, Union,
) )
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route
from sanic.http.constants import HTTP # type: ignore
from sanic.http.stream import Stream
from sanic.models.asgi import ASGIScope
from sanic.models.http_types import Credentials from sanic.models.http_types import Credentials
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.app import Sanic from sanic.app import Sanic
@@ -29,12 +34,12 @@ from http.cookies import SimpleCookie
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url # type: ignore from httptools import parse_url
from httptools.parser.errors import HttpParserInvalidURLError # type: ignore from httptools.parser.errors import HttpParserInvalidURLError
from sanic.compat import CancelledErrors, Header from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import BadURL, InvalidUsage, ServerError from sanic.exceptions import BadRequest, BadURL, ServerError
from sanic.headers import ( from sanic.headers import (
AcceptContainer, AcceptContainer,
Options, Options,
@@ -45,7 +50,7 @@ from sanic.headers import (
parse_host, parse_host,
parse_xforwarded, parse_xforwarded,
) )
from sanic.http import Http, Stage from sanic.http import Stage
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse from sanic.response import BaseHTTPResponse, HTTPResponse
@@ -81,6 +86,9 @@ class Request:
Properties of an HTTP request such as URL, headers, etc. Properties of an HTTP request such as URL, headers, etc.
""" """
_current: ContextVar[Request] = ContextVar("request")
_loads = json_loads
__slots__ = ( __slots__ = (
"__weakref__", "__weakref__",
"_cookies", "_cookies",
@@ -90,7 +98,9 @@ class Request:
"_port", "_port",
"_protocol", "_protocol",
"_remote_addr", "_remote_addr",
"_scheme",
"_socket", "_socket",
"_stream_id",
"_match_info", "_match_info",
"_name", "_name",
"app", "app",
@@ -127,6 +137,7 @@ class Request:
transport: TransportProtocol, transport: TransportProtocol,
app: Sanic, app: Sanic,
head: bytes = b"", head: bytes = b"",
stream_id: int = 0,
): ):
self.raw_url = url_bytes self.raw_url = url_bytes
@@ -136,6 +147,7 @@ class Request:
raise BadURL(f"Bad URL: {url_bytes.decode()}") raise BadURL(f"Bad URL: {url_bytes.decode()}")
self._id: Optional[Union[uuid.UUID, str, int]] = None self._id: Optional[Union[uuid.UUID, str, int]] = None
self._name: Optional[str] = None self._name: Optional[str] = None
self._stream_id = stream_id
self.app = app self.app = app
self.headers = Header(headers) self.headers = Header(headers)
@@ -152,8 +164,8 @@ class Request:
self.parsed_accept: Optional[AcceptContainer] = None self.parsed_accept: Optional[AcceptContainer] = None
self.parsed_credentials: Optional[Credentials] = None self.parsed_credentials: Optional[Credentials] = None
self.parsed_json = None self.parsed_json = None
self.parsed_form = None self.parsed_form: Optional[RequestParameters] = None
self.parsed_files = None self.parsed_files: Optional[RequestParameters] = None
self.parsed_token: Optional[str] = None self.parsed_token: Optional[str] = None
self.parsed_args: DefaultDict[ self.parsed_args: DefaultDict[
Tuple[bool, bool, str, str], RequestParameters Tuple[bool, bool, str, str], RequestParameters
@@ -162,21 +174,60 @@ class Request:
Tuple[bool, bool, str, str], List[Tuple[str, str]] Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list) ] = defaultdict(list)
self.request_middleware_started = False self.request_middleware_started = False
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
self._cookies: Optional[Dict[str, str]] = None self._cookies: Optional[Dict[str, str]] = None
self._match_info: Dict[str, Any] = {} self._match_info: Dict[str, Any] = {}
self.stream: Optional[Http] = None
self.route: Optional[Route] = None
self._protocol = None self._protocol = None
self.responded: bool = False
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
return f"<{class_name}: {self.method} {self.path}>" return f"<{class_name}: {self.method} {self.path}>"
@classmethod
def get_current(cls) -> Request:
"""
Retrieve the currrent request object
This implements `Context Variables
<https://docs.python.org/3/library/contextvars.html>`_
to allow for accessing the current request from anywhere.
Raises :exc:`sanic.exceptions.ServerError` if it is outside of
a request lifecycle.
.. code-block:: python
from sanic import Request
current_request = Request.get_current()
:return: the current :class:`sanic.request.Request`
"""
request = cls._current.get(None)
if not request:
raise ServerError("No current request")
return request
@classmethod @classmethod
def generate_id(*_): def generate_id(*_):
return uuid.uuid4() return uuid.uuid4()
@property
def stream_id(self):
"""
Access the HTTP/3 stream ID.
Raises :exc:`sanic.exceptions.ServerError` if it is not an
HTTP/3 request.
"""
if self.protocol.version is not HTTP.VERSION_3:
raise ServerError(
"Stream ID is only a property of a HTTP/3 request"
)
return self._stream_id
def reset_response(self): def reset_response(self):
try: try:
if ( if (
@@ -263,6 +314,9 @@ class Request:
# Connect the response # Connect the response
if isinstance(response, BaseHTTPResponse) and self.stream: if isinstance(response, BaseHTTPResponse) and self.stream:
response = self.stream.respond(response) response = self.stream.respond(response)
if isawaitable(response):
response = await response # type: ignore
# Run response middleware # Run response middleware
try: try:
response = await self.app._run_response_middleware( response = await self.app._run_response_middleware(
@@ -290,7 +344,19 @@ class Request:
self.body = b"".join([data async for data in self.stream]) self.body = b"".join([data async for data in self.stream])
@property @property
def name(self): def name(self) -> Optional[str]:
"""
The route name
In the following pattern:
.. code-block::
<AppName>.[<BlueprintName>.]<HandlerName>
:return: Route name
:rtype: Optional[str]
"""
if self._name: if self._name:
return self._name return self._name
elif self.route: elif self.route:
@@ -298,26 +364,47 @@ class Request:
return None return None
@property @property
def endpoint(self): def endpoint(self) -> Optional[str]:
"""
:return: Alias of :attr:`sanic.request.Request.name`
:rtype: Optional[str]
"""
return self.name return self.name
@property @property
def uri_template(self): def uri_template(self) -> Optional[str]:
return f"/{self.route.path}" """
:return: The defined URI template
:rtype: Optional[str]
"""
if self.route:
return f"/{self.route.path}"
return None
@property @property
def protocol(self): def protocol(self):
"""
:return: The HTTP protocol instance
"""
if not self._protocol: if not self._protocol:
self._protocol = self.transport.get_protocol() self._protocol = self.transport.get_protocol()
return self._protocol return self._protocol
@property @property
def raw_headers(self): def raw_headers(self) -> bytes:
"""
:return: The unparsed HTTP headers
:rtype: bytes
"""
_, headers = self.head.split(b"\r\n", 1) _, headers = self.head.split(b"\r\n", 1)
return bytes(headers) return bytes(headers)
@property @property
def request_line(self): def request_line(self) -> bytes:
"""
:return: The first line of a HTTP request
:rtype: bytes
"""
reqline, _ = self.head.split(b"\r\n", 1) reqline, _ = self.head.split(b"\r\n", 1)
return bytes(reqline) return bytes(reqline)
@@ -366,24 +453,35 @@ class Request:
return self._id # type: ignore return self._id # type: ignore
@property @property
def json(self): def json(self) -> Any:
"""
:return: The request body parsed as JSON
:rtype: Any
"""
if self.parsed_json is None: if self.parsed_json is None:
self.load_json() self.load_json()
return self.parsed_json return self.parsed_json
def load_json(self, loads=json_loads): def load_json(self, loads=None):
try: try:
if not loads:
loads = self.__class__._loads
self.parsed_json = loads(self.body) self.parsed_json = loads(self.body)
except Exception: except Exception:
if not self.body: if not self.body:
return None return None
raise InvalidUsage("Failed when parsing body as json") raise BadRequest("Failed when parsing body as json")
return self.parsed_json return self.parsed_json
@property @property
def accept(self) -> AcceptContainer: def accept(self) -> AcceptContainer:
"""
:return: The ``Accept`` header parsed
:rtype: AcceptContainer
"""
if self.parsed_accept is None: if self.parsed_accept is None:
accept_header = self.headers.getone("accept", "") accept_header = self.headers.getone("accept", "")
self.parsed_accept = parse_accept(accept_header) self.parsed_accept = parse_accept(accept_header)
@@ -426,33 +524,60 @@ class Request:
pass pass
return self.parsed_credentials return self.parsed_credentials
def get_form(
self, keep_blank_values: bool = False
) -> Optional[RequestParameters]:
"""
Method to extract and parse the form data from a request.
:param keep_blank_values:
Whether to discard blank values from the form data
:type keep_blank_values: bool
:return: the parsed form data
:rtype: Optional[RequestParameters]
"""
self.parsed_form = RequestParameters()
self.parsed_files = RequestParameters()
content_type = self.headers.getone(
"content-type", DEFAULT_HTTP_CONTENT_TYPE
)
content_type, parameters = parse_content_header(content_type)
try:
if content_type == "application/x-www-form-urlencoded":
self.parsed_form = RequestParameters(
parse_qs(
self.body.decode("utf-8"),
keep_blank_values=keep_blank_values,
)
)
elif content_type == "multipart/form-data":
# TODO: Stream this instead of reading to/from memory
boundary = parameters["boundary"].encode( # type: ignore
"utf-8"
) # type: ignore
self.parsed_form, self.parsed_files = parse_multipart_form(
self.body, boundary
)
except Exception:
error_logger.exception("Failed when parsing form")
return self.parsed_form
@property @property
def form(self): def form(self):
"""
:return: The request body parsed as form data
"""
if self.parsed_form is None: if self.parsed_form is None:
self.parsed_form = RequestParameters() self.get_form()
self.parsed_files = RequestParameters()
content_type = self.headers.getone(
"content-type", DEFAULT_HTTP_CONTENT_TYPE
)
content_type, parameters = parse_content_header(content_type)
try:
if content_type == "application/x-www-form-urlencoded":
self.parsed_form = RequestParameters(
parse_qs(self.body.decode("utf-8"))
)
elif content_type == "multipart/form-data":
# TODO: Stream this instead of reading to/from memory
boundary = parameters["boundary"].encode("utf-8")
self.parsed_form, self.parsed_files = parse_multipart_form(
self.body, boundary
)
except Exception:
error_logger.exception("Failed when parsing form")
return self.parsed_form return self.parsed_form
@property @property
def files(self): def files(self):
"""
:return: The request body parsed as uploaded files
"""
if self.parsed_files is None: if self.parsed_files is None:
self.form # compute form to get files self.form # compute form to get files
@@ -466,8 +591,8 @@ class Request:
errors: str = "replace", errors: str = "replace",
) -> RequestParameters: ) -> RequestParameters:
""" """
Method to parse `query_string` using `urllib.parse.parse_qs`. Method to parse ``query_string`` using ``urllib.parse.parse_qs``.
This methods is used by `args` property. This methods is used by ``args`` property.
Can be used directly if you need to change default parameters. Can be used directly if you need to change default parameters.
:param keep_blank_values: :param keep_blank_values:
@@ -516,6 +641,10 @@ class Request:
] ]
args = property(get_args) args = property(get_args)
"""
Convenience property to access :meth:`Request.get_args` with
default values.
"""
def get_query_args( def get_query_args(
self, self,
@@ -635,6 +764,9 @@ class Request:
@property @property
def socket(self): def socket(self):
"""
:return: Information about the connected socket if available
"""
return self.conn_info.peername if self.conn_info else (None, None) return self.conn_info.peername if self.conn_info else (None, None)
@property @property
@@ -645,6 +777,13 @@ class Request:
""" """
return self._parsed_url.path.decode("utf-8") return self._parsed_url.path.decode("utf-8")
@property
def network_paths(self):
"""
Access the network paths if available
"""
return self.conn_info.network_paths
# Proxy properties (using SERVER_NAME/forwarded/request/transport info) # Proxy properties (using SERVER_NAME/forwarded/request/transport info)
@property @property
@@ -698,23 +837,25 @@ class Request:
:return: http|https|ws|wss or arbitrary value given by the headers. :return: http|https|ws|wss or arbitrary value given by the headers.
:rtype: str :rtype: str
""" """
if "//" in self.app.config.get("SERVER_NAME", ""): if not hasattr(self, "_scheme"):
return self.app.config.SERVER_NAME.split("//")[0] if "//" in self.app.config.get("SERVER_NAME", ""):
if "proto" in self.forwarded: return self.app.config.SERVER_NAME.split("//")[0]
return str(self.forwarded["proto"]) if "proto" in self.forwarded:
return str(self.forwarded["proto"])
if ( if (
self.app.websocket_enabled self.app.websocket_enabled
and self.headers.getone("upgrade", "").lower() == "websocket" and self.headers.getone("upgrade", "").lower() == "websocket"
): ):
scheme = "ws" scheme = "ws"
else: else:
scheme = "http" scheme = "http"
if self.transport.get_extra_info("sslcontext"): if self.transport.get_extra_info("sslcontext"):
scheme += "s" scheme += "s"
self._scheme = scheme
return scheme return self._scheme
@property @property
def host(self) -> str: def host(self) -> str:
@@ -819,6 +960,21 @@ class Request:
view_name, _external=True, _scheme=scheme, _server=netloc, **kwargs view_name, _external=True, _scheme=scheme, _server=netloc, **kwargs
) )
@property
def scope(self) -> ASGIScope:
"""
:return: The ASGI scope of the request.
If the app isn't an ASGI app, then raises an exception.
:rtype: Optional[ASGIScope]
"""
if not self.app.asgi:
raise NotImplementedError(
"App isn't running in ASGI mode. "
"Scope is only available for ASGI apps."
)
return self.transport.scope
class File(NamedTuple): class File(NamedTuple):
""" """

View File

@@ -1,9 +1,12 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
from email.utils import formatdate
from functools import partial from functools import partial
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from pathlib import PurePath from pathlib import PurePath
from time import time
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@@ -19,17 +22,23 @@ from typing import (
) )
from urllib.parse import quote_plus from urllib.parse import quote_plus
from sanic.compat import Header, open_async from sanic.compat import Header, open_async, stat_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
from sanic.exceptions import SanicException, ServerError from sanic.exceptions import SanicException, ServerError
from sanic.helpers import has_message_body, remove_entity_headers from sanic.helpers import (
Default,
_default,
has_message_body,
remove_entity_headers,
)
from sanic.http import Http from sanic.http import Http
from sanic.models.protocol_types import HTMLProtocol, Range from sanic.models.protocol_types import HTMLProtocol, Range
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.asgi import ASGIApp from sanic.asgi import ASGIApp
from sanic.http.http3 import HTTPReceiver
from sanic.request import Request from sanic.request import Request
else: else:
Request = TypeVar("Request") Request = TypeVar("Request")
@@ -66,11 +75,15 @@ class BaseHTTPResponse:
self.asgi: bool = False self.asgi: bool = False
self.body: Optional[bytes] = None self.body: Optional[bytes] = None
self.content_type: Optional[str] = None self.content_type: Optional[str] = None
self.stream: Optional[Union[Http, ASGIApp]] = None self.stream: Optional[Union[Http, ASGIApp, HTTPReceiver]] = None
self.status: int = None self.status: int = None
self.headers = Header({}) self.headers = Header({})
self._cookies: Optional[CookieJar] = None self._cookies: Optional[CookieJar] = None
def __repr__(self):
class_name = self.__class__.__name__
return f"<{class_name}: {self.status} {self.content_type}>"
def _encode_body(self, data: Optional[AnyStr]): def _encode_body(self, data: Optional[AnyStr]):
if data is None: if data is None:
return b"" return b""
@@ -149,7 +162,10 @@ class BaseHTTPResponse:
if hasattr(data, "encode") if hasattr(data, "encode")
else data or b"" else data or b""
) )
await self.stream.send(data, end_stream=end_stream) await self.stream.send(
data, # type: ignore
end_stream=end_stream or False,
)
class HTTPResponse(BaseHTTPResponse): class HTTPResponse(BaseHTTPResponse):
@@ -309,6 +325,9 @@ async def file(
mime_type: Optional[str] = None, mime_type: Optional[str] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
filename: Optional[str] = None, filename: Optional[str] = None,
last_modified: Optional[Union[datetime, float, int, Default]] = _default,
max_age: Optional[Union[float, int]] = None,
no_store: Optional[bool] = None,
_range: Optional[Range] = None, _range: Optional[Range] = None,
) -> HTTPResponse: ) -> HTTPResponse:
"""Return a response object with file data. """Return a response object with file data.
@@ -317,6 +336,9 @@ async def file(
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
:param filename: Override filename. :param filename: Override filename.
:param last_modified: The last modified date and time of the file.
:param max_age: Max age for cache control.
:param no_store: Any cache should not store this response.
:param _range: :param _range:
""" """
headers = headers or {} headers = headers or {}
@@ -324,6 +346,34 @@ async def file(
headers.setdefault( headers.setdefault(
"Content-Disposition", f'attachment; filename="{filename}"' "Content-Disposition", f'attachment; filename="{filename}"'
) )
if isinstance(last_modified, datetime):
last_modified = last_modified.replace(microsecond=0).timestamp()
elif isinstance(last_modified, Default):
stat = await stat_async(location)
last_modified = stat.st_mtime
if last_modified:
headers.setdefault(
"last-modified", formatdate(last_modified, usegmt=True)
)
if no_store:
cache_control = "no-store"
elif max_age:
cache_control = f"public, max-age={max_age}"
headers.setdefault(
"expires",
formatdate(
time() + max_age,
usegmt=True,
),
)
else:
cache_control = "no-cache"
headers.setdefault("cache-control", cache_control)
filename = filename or path.split(location)[-1] filename = filename or path.split(location)[-1]
async with await open_async(location, mode="rb") as f: async with await open_async(location, mode="rb") as f:
@@ -377,8 +427,7 @@ def redirect(
class ResponseStream: class ResponseStream:
""" """
ResponseStream is a compat layer to bridge the gap after the deprecation ResponseStream is a compat layer to bridge the gap after the deprecation
of StreamingHTTPResponse. In v22.6 it will be removed when: of StreamingHTTPResponse. It will be removed when:
- stream is removed
- file_stream is moved to new style streaming - file_stream is moved to new style streaming
- file and file_stream are combined into a single API - file and file_stream are combined into a single API
""" """
@@ -506,38 +555,3 @@ async def file_stream(
headers=headers, headers=headers,
content_type=mime_type, content_type=mime_type,
) )
def stream(
streaming_fn: Callable[
[Union[BaseHTTPResponse, ResponseStream]], Coroutine[Any, Any, None]
],
status: int = 200,
headers: Optional[Dict[str, str]] = None,
content_type: str = "text/plain; charset=utf-8",
) -> ResponseStream:
"""Accepts a coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `ResponseStream`.
Example usage::
@app.route("/")
async def index(request):
async def streaming_fn(response):
await response.write('foo')
await response.write('bar')
return stream(streaming_fn, content_type='text/plain')
:param streaming_fn: A coroutine accepts a response and
writes content to that response.
:param status: HTTP status.
:param content_type: Specific content_type.
:param headers: Custom Headers.
"""
return ResponseStream(
streaming_fn,
headers=headers,
content_type=content_type,
status=status,
)

View File

@@ -5,16 +5,14 @@ from inspect import signature
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID from uuid import UUID
from sanic_routing import BaseRouter # type: ignore from sanic_routing import BaseRouter
from sanic_routing.exceptions import NoMethod # type: ignore from sanic_routing.exceptions import NoMethod
from sanic_routing.exceptions import ( from sanic_routing.exceptions import NotFound as RoutingNotFound
NotFound as RoutingNotFound, # type: ignore from sanic_routing.route import Route
)
from sanic_routing.route import Route # type: ignore
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format from sanic.errorpages import check_error_format
from sanic.exceptions import MethodNotSupported, NotFound, SanicException from sanic.exceptions import MethodNotAllowed, NotFound, SanicException
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
@@ -43,7 +41,7 @@ class Router(BaseRouter):
except RoutingNotFound as e: except RoutingNotFound as e:
raise NotFound("Requested URL {} not found".format(e.path)) raise NotFound("Requested URL {} not found".format(e.path))
except NoMethod as e: except NoMethod as e:
raise MethodNotSupported( raise MethodNotAllowed(
"Method {} not allowed for URL {}".format(method, path), "Method {} not allowed for URL {}".format(method, path),
method=method, method=method,
allowed_methods=e.allowed_methods, allowed_methods=e.allowed_methods,

View File

@@ -5,7 +5,6 @@ import asyncio
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sanic.exceptions import SanicException from sanic.exceptions import SanicException
from sanic.log import deprecation
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -35,15 +34,6 @@ class AsyncioServer:
self.serve_coro = serve_coro self.serve_coro = serve_coro
self.server = None self.server = None
@property
def init(self):
deprecation(
"AsyncioServer.init has been deprecated and will be removed "
"in v22.6. Use Sanic.state.is_started instead.",
22.6,
)
return self.app.state.is_started
def startup(self): def startup(self):
""" """
Trigger "before_server_start" events Trigger "before_server_start" events

View File

@@ -4,7 +4,7 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic.app import Sanic from sanic.app import Sanic
import asyncio import asyncio

View File

@@ -2,33 +2,89 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from sanic.http.constants import HTTP
from sanic.http.http3 import Http3
from sanic.touchup.meta import TouchUpMeta from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic.app import Sanic from sanic.app import Sanic
import sys
from asyncio import CancelledError from asyncio import CancelledError
from time import monotonic as current_time from time import monotonic as current_time
from sanic.exceptions import RequestTimeout, ServiceUnavailable from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage from sanic.http import Http, Stage
from sanic.log import error_logger, logger from sanic.log import Colors, error_logger, logger
from sanic.models.server_types import ConnInfo from sanic.models.server_types import ConnInfo
from sanic.request import Request from sanic.request import Request
from sanic.server.protocols.base_protocol import SanicProtocol from sanic.server.protocols.base_protocol import SanicProtocol
class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): ConnectionProtocol = type("ConnectionProtocol", (), {})
try:
from aioquic.asyncio import QuicConnectionProtocol
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.quic.events import (
DatagramFrameReceived,
ProtocolNegotiated,
QuicEvent,
)
ConnectionProtocol = QuicConnectionProtocol
except ModuleNotFoundError: # no cov
...
class HttpProtocolMixin:
__slots__ = ()
__version__: HTTP
def _setup_connection(self, *args, **kwargs):
self._http = self.HTTP_CLASS(self, *args, **kwargs)
self._time = current_time()
try:
self.check_timeouts()
except AttributeError:
...
def _setup(self):
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
@property
def http(self):
if not hasattr(self, "_http"):
return None
return self._http
@property
def version(self):
return self.__class__.__version__
class HttpProtocol(HttpProtocolMixin, SanicProtocol, metaclass=TouchUpMeta):
""" """
This class provides implements the HTTP 1.1 protocol on top of our This class provides implements the HTTP 1.1 protocol on top of our
Sanic Server transport Sanic Server transport
""" """
HTTP_CLASS = Http
__touchup__ = ( __touchup__ = (
"send", "send",
"connection_task", "connection_task",
) )
__version__ = HTTP.VERSION_1
__slots__ = ( __slots__ = (
# request params # request params
"request", "request",
@@ -70,25 +126,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
unix=unix, unix=unix,
) )
self.url = None self.url = None
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
self.state = state if state else {} self.state = state if state else {}
self._setup()
if "requests_count" not in self.state: if "requests_count" not in self.state:
self.state["requests_count"] = 0 self.state["requests_count"] = 0
self._exception = None self._exception = None
def _setup_connection(self):
self._http = Http(self)
self._time = current_time()
self.check_timeouts()
async def connection_task(self): # no cov async def connection_task(self): # no cov
""" """
Run a HTTP connection. Run a HTTP connection.
@@ -169,7 +212,10 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
) )
self.loop.call_later(max(0.1, interval), self.check_timeouts) self.loop.call_later(max(0.1, interval), self.check_timeouts)
return return
self._task.cancel() cancel_msg_args = ()
if sys.version_info >= (3, 9):
cancel_msg_args = ("Cancel connection task with a timeout",)
self._task.cancel(*cancel_msg_args)
except Exception: except Exception:
error_logger.exception("protocol.check_timeouts") error_logger.exception("protocol.check_timeouts")
@@ -236,3 +282,39 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
self._data_received.set() self._data_received.set()
except Exception: except Exception:
error_logger.exception("protocol.data_received") error_logger.exception("protocol.data_received")
class Http3Protocol(HttpProtocolMixin, ConnectionProtocol): # type: ignore
HTTP_CLASS = Http3
__version__ = HTTP.VERSION_3
def __init__(self, *args, app: Sanic, **kwargs) -> None:
self.app = app
super().__init__(*args, **kwargs)
self._setup()
self._connection: Optional[H3Connection] = None
def quic_event_received(self, event: QuicEvent) -> None:
logger.debug(
f"{Colors.BLUE}[quic_event_received]: "
f"{Colors.PURPLE}{event}{Colors.END}",
extra={"verbosity": 2},
)
if isinstance(event, ProtocolNegotiated):
self._setup_connection(transmit=self.transmit)
if event.alpn_protocol in H3_ALPN:
self._connection = H3Connection(
self._quic, enable_webtransport=True
)
elif isinstance(event, DatagramFrameReceived):
if event.data == b"quack":
self._quic.send_datagram_frame(b"quack-ack")
# pass event to the HTTP layer
if self._connection is not None:
for http_event in self._connection.handle_event(event):
self._http.http_event_received(http_event)
@property
def connection(self) -> Optional[H3Connection]:
return self._connection

View File

@@ -11,7 +11,7 @@ from sanic.server import HttpProtocol
from ..websockets.impl import WebsocketImplProtocol from ..websockets.impl import WebsocketImplProtocol
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from websockets import http11 from websockets import http11

View File

@@ -6,6 +6,9 @@ from ssl import SSLContext
from typing import TYPE_CHECKING, Dict, Optional, Type, Union from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from sanic.config import Config from sanic.config import Config
from sanic.exceptions import ServerError
from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context
from sanic.server.events import trigger_events from sanic.server.events import trigger_events
@@ -23,10 +26,11 @@ from signal import signal as signal_func
from sanic.application.ext import setup_ext from sanic.application.ext import setup_ext
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.http.http3 import SessionTicketStore, get_config
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.models.server_types import Signal from sanic.models.server_types import Signal
from sanic.server.async_server import AsyncioServer from sanic.server.async_server import AsyncioServer
from sanic.server.protocols.http_protocol import HttpProtocol from sanic.server.protocols.http_protocol import Http3Protocol, HttpProtocol
from sanic.server.socket import ( from sanic.server.socket import (
bind_socket, bind_socket,
bind_unix_socket, bind_unix_socket,
@@ -34,6 +38,14 @@ from sanic.server.socket import (
) )
try:
from aioquic.asyncio import serve as quic_serve
HTTP3_AVAILABLE = True
except ModuleNotFoundError: # no cov
HTTP3_AVAILABLE = False
def serve( def serve(
host, host,
port, port,
@@ -52,6 +64,7 @@ def serve(
signal=Signal(), signal=Signal(),
state=None, state=None,
asyncio_server_kwargs=None, asyncio_server_kwargs=None,
version=HTTP.VERSION_1,
): ):
"""Start asynchronous HTTP Server on an individual process. """Start asynchronous HTTP Server on an individual process.
@@ -88,6 +101,87 @@ def serve(
app.asgi = False app.asgi = False
if version is HTTP.VERSION_3:
return _serve_http_3(host, port, app, loop, ssl)
return _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
)
def _setup_system_signals(
app: Sanic,
run_multiple: bool,
register_sys_signals: bool,
loop: asyncio.AbstractEventLoop,
) -> None:
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
os.environ["SANIC_WORKER_PROCESS"] = "true"
# Register signals for graceful termination
if register_sys_signals:
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
def _run_server_forever(loop, before_stop, after_stop, cleanup, unix):
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
loop.run_forever()
except KeyboardInterrupt:
pass
finally:
logger.info("Stopping worker [%s]", pid)
loop.run_until_complete(before_stop())
if cleanup:
cleanup()
loop.run_until_complete(after_stop())
remove_unix_socket(unix)
def _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
):
connections = connections if connections is not None else set() connections = connections if connections is not None else set()
protocol_kwargs = _build_protocol_kwargs(protocol, app.config) protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
server = partial( server = partial(
@@ -135,30 +229,7 @@ def serve(
error_logger.exception("Unable to start server", exc_info=True) error_logger.exception("Unable to start server", exc_info=True)
return return
# Ignore SIGINT when run_multiple def _cleanup():
if run_multiple:
signal_func(SIGINT, SIG_IGN)
os.environ["SANIC_WORKER_PROCESS"] = "true"
# Register signals for graceful termination
if register_sys_signals:
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
loop.run_forever()
finally:
logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided
loop.run_until_complete(app._server_event("shutdown", "before"))
# Wait for event loop to finish and all connections to drain # Wait for event loop to finish and all connections to drain
http_server.close() http_server.close()
loop.run_until_complete(http_server.wait_closed()) loop.run_until_complete(http_server.wait_closed())
@@ -188,8 +259,55 @@ def serve(
conn.websocket.fail_connection(code=1001) conn.websocket.fail_connection(code=1001)
else: else:
conn.abort() conn.abort()
loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix) _setup_system_signals(app, run_multiple, register_sys_signals, loop)
loop.run_until_complete(app._server_event("init", "after"))
_run_server_forever(
loop,
partial(app._server_event, "shutdown", "before"),
partial(app._server_event, "shutdown", "after"),
_cleanup,
unix,
)
def _serve_http_3(
host,
port,
app,
loop,
ssl,
register_sys_signals: bool = True,
run_multiple: bool = False,
):
if not HTTP3_AVAILABLE:
raise ServerError(
"Cannot run HTTP/3 server without aioquic installed. "
)
protocol = partial(Http3Protocol, app=app)
ticket_store = SessionTicketStore()
ssl_context = get_ssl_context(app, ssl)
config = get_config(app, ssl_context)
coro = quic_serve(
host,
port,
configuration=config,
create_protocol=protocol,
session_ticket_fetcher=ticket_store.pop,
session_ticket_handler=ticket_store.add,
)
server = AsyncioServer(app, loop, coro, [])
loop.run_until_complete(server.startup())
loop.run_until_complete(server.before_start())
loop.run_until_complete(server)
_setup_system_signals(app, run_multiple, register_sys_signals, loop)
loop.run_until_complete(server.after_start())
# TODO: Create connection cleanup and graceful shutdown
cleanup = None
_run_server_forever(
loop, server.before_stop, server.after_stop, cleanup, None
)
def serve_single(server_settings): def serve_single(server_settings):

View File

@@ -9,7 +9,7 @@ from websockets.typing import Data
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from .impl import WebsocketImplProtocol from .impl import WebsocketImplProtocol
UTF8Decoder = codecs.getincrementaldecoder("utf-8") UTF8Decoder = codecs.getincrementaldecoder("utf-8")
@@ -37,7 +37,7 @@ class WebsocketFrameAssembler:
"get_id", "get_id",
"put_id", "put_id",
) )
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
protocol: "WebsocketImplProtocol" protocol: "WebsocketImplProtocol"
read_mutex: asyncio.Lock read_mutex: asyncio.Lock
write_mutex: asyncio.Lock write_mutex: asyncio.Lock

View File

@@ -6,9 +6,9 @@ from enum import Enum
from inspect import isawaitable from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union, cast from typing import Any, Dict, List, Optional, Tuple, Union, cast
from sanic_routing import BaseRouter, Route, RouteGroup # type: ignore from sanic_routing import BaseRouter, Route, RouteGroup
from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.exceptions import NotFound
from sanic_routing.utils import path_to_parts # type: ignore from sanic_routing.utils import path_to_parts
from sanic.exceptions import InvalidSignal from sanic.exceptions import InvalidSignal
from sanic.log import error_logger, logger from sanic.log import error_logger, logger

View File

@@ -1,3 +1,4 @@
from .altsvc import AltSvcCheck # noqa
from .base import BaseScheme from .base import BaseScheme
from .ode import OptionalDispatchEvent # noqa from .ode import OptionalDispatchEvent # noqa

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
from ast import Assign, Constant, NodeTransformer, Subscript
from typing import TYPE_CHECKING, Any, List
from sanic.http.constants import HTTP
from .base import BaseScheme
if TYPE_CHECKING:
from sanic import Sanic
class AltSvcCheck(BaseScheme):
ident = "ALTSVC"
def visitors(self) -> List[NodeTransformer]:
return [RemoveAltSvc(self.app, self.app.state.verbosity)]
class RemoveAltSvc(NodeTransformer):
def __init__(self, app: Sanic, verbosity: int = 0) -> None:
self._app = app
self._verbosity = verbosity
self._versions = {
info.settings["version"] for info in app.state.server_info
}
def visit_Assign(self, node: Assign) -> Any:
if any(self._matches(target) for target in node.targets):
if self._should_remove():
return None
assert isinstance(node.value, Constant)
node.value.value = self.value()
return node
def _should_remove(self) -> bool:
return len(self._versions) == 1
@staticmethod
def _matches(node) -> bool:
return (
isinstance(node, Subscript)
and isinstance(node.slice, Constant)
and node.slice.value == "alt-svc"
)
def value(self):
values = []
for info in self._app.state.server_info:
port = info.settings["port"]
version = info.settings["version"]
if version is HTTP.VERSION_3:
values.append(f'h3=":{port}"')
return ", ".join(values)

View File

@@ -1,5 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Set, Type from ast import NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any, Dict, List, Set, Type
class BaseScheme(ABC): class BaseScheme(ABC):
@@ -10,11 +13,26 @@ class BaseScheme(ABC):
self.app = app self.app = app
@abstractmethod @abstractmethod
def run(self, method, module_globals) -> None: def visitors(self) -> List[NodeTransformer]:
... ...
def __init_subclass__(cls): def __init_subclass__(cls):
BaseScheme._registry.add(cls) BaseScheme._registry.add(cls)
def __call__(self, method, module_globals): def __call__(self):
return self.run(method, module_globals) return self.visitors()
@classmethod
def build(cls, method, module_globals, app):
raw_source = getsource(method)
src = dedent(raw_source)
node = parse(src)
for scheme in cls._registry:
for visitor in scheme(app)():
node = visitor.visit(node)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]

View File

@@ -1,7 +1,5 @@
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse from ast import Attribute, Await, Expr, NodeTransformer
from inspect import getsource from typing import Any, List
from textwrap import dedent
from typing import Any
from sanic.log import logger from sanic.log import logger
@@ -20,18 +18,8 @@ class OptionalDispatchEvent(BaseScheme):
signal.name for signal in app.signal_router.routes signal.name for signal in app.signal_router.routes
] ]
def run(self, method, module_globals): def visitors(self) -> List[NodeTransformer]:
raw_source = getsource(method) return [RemoveDispatch(self._registered_events)]
src = dedent(raw_source)
tree = parse(src)
node = RemoveDispatch(
self._registered_events, self.app.state.verbosity
).visit(tree)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]
def _sync_events(self): def _sync_events(self):
all_events = set() all_events = set()
@@ -64,9 +52,8 @@ class OptionalDispatchEvent(BaseScheme):
class RemoveDispatch(NodeTransformer): class RemoveDispatch(NodeTransformer):
def __init__(self, registered_events, verbosity: int = 0) -> None: def __init__(self, registered_events) -> None:
self._registered_events = registered_events self._registered_events = registered_events
self._verbosity = verbosity
def visit_Expr(self, node: Expr) -> Any: def visit_Expr(self, node: Expr) -> Any:
call = node.value call = node.value
@@ -83,8 +70,10 @@ class RemoveDispatch(NodeTransformer):
if hasattr(event, "s"): if hasattr(event, "s"):
event_name = getattr(event, "value", event.s) event_name = getattr(event, "value", event.s)
if self._not_registered(event_name): if self._not_registered(event_name):
if self._verbosity >= 2: logger.debug(
logger.debug(f"Disabling event: {event_name}") f"Disabling event: {event_name}",
extra={"verbosity": 2},
)
return None return None
return node return node

View File

@@ -21,10 +21,8 @@ class TouchUp:
module = getmodule(target) module = getmodule(target)
module_globals = dict(getmembers(module)) module_globals = dict(getmembers(module))
modified = BaseScheme.build(method, module_globals, app)
for scheme in BaseScheme._registry: setattr(target, method_name, modified)
modified = scheme(app)(method, module_globals)
setattr(target, method_name, modified)
target.__touched__ = True target.__touched__ = True

View File

@@ -13,7 +13,7 @@ from typing import (
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
if TYPE_CHECKING: # no cov if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint

View File

@@ -1,13 +1,2 @@
[flake8] [flake8]
ignore = E203, W503 ignore = E203, W503
[isort]
atomic = true
default_section = THIRDPARTY
include_trailing_comma = true
known_first_party = sanic
known_third_party = pytest
line_length = 79
lines_after_imports = 2
lines_between_types = 1
multi_line_output = 3

View File

@@ -122,6 +122,7 @@ docs_require = [
"docutils", "docutils",
"pygments", "pygments",
"m2r2", "m2r2",
"enum-tools[sphinx]",
"mistune<2.0.0", "mistune<2.0.0",
] ]
@@ -149,6 +150,7 @@ extras_require = {
"docs": docs_require, "docs": docs_require,
"all": all_require, "all": all_require,
"ext": ["sanic-ext"], "ext": ["sanic-ext"],
"http3": ["aioquic"],
} }
setup_kwargs["install_requires"] = requirements setup_kwargs["install_requires"] = requirements

0
tests/__init__.py Normal file
View File

View File

@@ -25,6 +25,10 @@ class AsyncMock(Mock):
def __await__(self): def __await__(self):
return self().__await__() return self().__await__()
def reset_mock(self, *args, **kwargs):
super().reset_mock(*args, **kwargs)
self.await_count = 0
def assert_awaited_once(self): def assert_awaited_once(self):
if not self.await_count == 1: if not self.await_count == 1:
msg = ( msg = (
@@ -32,3 +36,13 @@ class AsyncMock(Mock):
f" Awaited {self.await_count} times." f" Awaited {self.await_count} times."
) )
raise AssertionError(msg) raise AssertionError(msg)
def assert_awaited_once_with(self, *args, **kwargs):
if not self.await_count == 1:
msg = (
f"Expected to have been awaited once."
f" Awaited {self.await_count} times."
)
raise AssertionError(msg)
self.assert_awaited_once()
return self.assert_called_with(*args, **kwargs)

47
tests/client.py Normal file
View File

@@ -0,0 +1,47 @@
import asyncio
from textwrap import dedent
from typing import AnyStr
class RawClient:
CRLF = b"\r\n"
def __init__(self, host: str, port: int):
self.reader = None
self.writer = None
self.host = host
self.port = port
async def connect(self):
self.reader, self.writer = await asyncio.open_connection(
self.host, self.port
)
async def close(self):
self.writer.close()
await self.writer.wait_closed()
async def send(self, message: AnyStr):
if isinstance(message, str):
msg = self._clean(message).encode("utf-8")
else:
msg = message
await self._send(msg)
async def _send(self, message: bytes):
if not self.writer:
raise Exception("No open write stream")
self.writer.write(message)
async def recv(self, nbytes: int = -1) -> bytes:
if not self.reader:
raise Exception("No open read stream")
return await self.reader.read(nbytes)
def _clean(self, message: str) -> str:
return (
dedent(message)
.lstrip("\n")
.replace("\n", self.CRLF.decode("utf-8"))
)

View File

@@ -150,6 +150,7 @@ def app(request):
yield app yield app
for target, method_name in TouchUp._registry: for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name]) setattr(target, method_name, CACHE[method_name])
Sanic._app_registry.clear()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

0
tests/http3/__init__.py Normal file
View File

View File

@@ -0,0 +1,294 @@
from unittest.mock import Mock
import pytest
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import ProtocolNegotiated
from sanic import Request, Sanic
from sanic.compat import Header
from sanic.config import DEFAULT_CONFIG
from sanic.exceptions import PayloadTooLarge
from sanic.http.constants import Stage
from sanic.http.http3 import Http3, HTTPReceiver
from sanic.models.server_types import ConnInfo
from sanic.response import empty, json
from sanic.server.protocols.http_protocol import Http3Protocol
try:
from unittest.mock import AsyncMock
except ImportError:
from tests.asyncmock import AsyncMock # type: ignore
pytestmark = pytest.mark.asyncio
@pytest.fixture(autouse=True)
async def setup(app: Sanic):
@app.get("/")
async def handler(*_):
return empty()
app.router.finalize()
app.signal_router.finalize()
app.signal_router.allow_fail_builtin = False
@pytest.fixture
def http_request(app):
return Request(b"/", Header({}), "3", "GET", Mock(), app)
def generate_protocol(app):
connection = QuicConnection(configuration=QuicConfiguration())
connection._ack_delay = 0
connection._loss = Mock()
connection._loss.spaces = []
connection._loss.get_loss_detection_time = lambda: None
connection.datagrams_to_send = Mock(return_value=[]) # type: ignore
return Http3Protocol(
connection,
app=app,
stream_handler=None,
)
def generate_http_receiver(app, http_request) -> HTTPReceiver:
protocol = generate_protocol(app)
receiver = HTTPReceiver(
protocol.transmit,
protocol,
http_request,
)
http_request.stream = receiver
return receiver
def test_http_receiver_init(app: Sanic, http_request: Request):
receiver = generate_http_receiver(app, http_request)
assert receiver.request_body is None
assert receiver.stage is Stage.IDLE
assert receiver.headers_sent is False
assert receiver.response is None
assert receiver.request_max_size == DEFAULT_CONFIG["REQUEST_MAX_SIZE"]
assert receiver.request_bytes == 0
async def test_http_receiver_run_request(app: Sanic, http_request: Request):
handler = AsyncMock()
class mock_handle(Sanic):
handle_request = handler
app.__class__ = mock_handle
receiver = generate_http_receiver(app, http_request)
receiver.protocol.quic_event_received(
ProtocolNegotiated(alpn_protocol="h3")
)
await receiver.run()
handler.assert_awaited_once_with(receiver.request)
async def test_http_receiver_run_exception(app: Sanic, http_request: Request):
handler = AsyncMock()
class mock_handle(Sanic):
handle_exception = handler
app.__class__ = mock_handle
receiver = generate_http_receiver(app, http_request)
receiver.protocol.quic_event_received(
ProtocolNegotiated(alpn_protocol="h3")
)
exception = Exception("Oof")
await receiver.run(exception)
handler.assert_awaited_once_with(receiver.request, exception)
handler.reset_mock()
receiver.stage = Stage.REQUEST
await receiver.run(exception)
handler.assert_awaited_once_with(receiver.request, exception)
def test_http_receiver_respond(app: Sanic, http_request: Request):
receiver = generate_http_receiver(app, http_request)
response = empty()
receiver.stage = Stage.RESPONSE
with pytest.raises(RuntimeError, match="Response already started"):
receiver.respond(response)
receiver.stage = Stage.HANDLER
receiver.response = Mock()
resp = receiver.respond(response)
assert receiver.response is resp
assert resp is response
assert response.stream is receiver
def test_http_receiver_receive_body(app: Sanic, http_request: Request):
receiver = generate_http_receiver(app, http_request)
receiver.request_max_size = 4
receiver.receive_body(b"..")
assert receiver.request.body == b".."
receiver.receive_body(b"..")
assert receiver.request.body == b"...."
with pytest.raises(
PayloadTooLarge, match="Request body exceeds the size limit"
):
receiver.receive_body(b"..")
def test_http3_events(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
http3.http_event_received(DataReceived(b"foobar", 1, False))
receiver = http3.receivers[1]
assert len(http3.receivers) == 1
assert receiver.request.stream_id == 1
assert receiver.request.path == "/location"
assert receiver.request.method == "GET"
assert receiver.request.headers["foo"] == "bar"
assert receiver.request.body == b"foobar"
async def test_send_headers(app: Sanic, http_request: Request):
send_headers_mock = Mock()
existing_send_headers = H3Connection.send_headers
receiver = generate_http_receiver(app, http_request)
receiver.protocol.quic_event_received(
ProtocolNegotiated(alpn_protocol="h3")
)
http_request._protocol = receiver.protocol
def send_headers(*args, **kwargs):
send_headers_mock(*args, **kwargs)
return existing_send_headers(
receiver.protocol.connection, *args, **kwargs
)
receiver.protocol.connection.send_headers = send_headers
receiver.head_only = False
response = json({}, status=201, headers={"foo": "bar"})
with pytest.raises(RuntimeError, match="no response"):
receiver.send_headers()
receiver.response = response
receiver.send_headers()
assert receiver.headers_sent
assert receiver.stage is Stage.RESPONSE
send_headers_mock.assert_called_once_with(
stream_id=0,
headers=[
(b":status", b"201"),
(b"foo", b"bar"),
(b"content-length", b"2"),
(b"content-type", b"application/json"),
],
)
def test_multiple_streams(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
2,
False,
)
)
receiver1 = http3.get_receiver_by_stream_id(1)
receiver2 = http3.get_receiver_by_stream_id(2)
assert len(http3.receivers) == 2
assert isinstance(receiver1, HTTPReceiver)
assert isinstance(receiver2, HTTPReceiver)
assert receiver1 is not receiver2
def test_request_stream_id(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
receiver = http3.get_receiver_by_stream_id(1)
assert isinstance(receiver.request, Request)
assert receiver.request.stream_id == 1
def test_request_conn_info(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
receiver = http3.get_receiver_by_stream_id(1)
assert isinstance(receiver.request.conn_info, ConnInfo)

114
tests/http3/test_server.py Normal file
View File

@@ -0,0 +1,114 @@
import logging
import sys
from asyncio import Event
from pathlib import Path
import pytest
from sanic import Sanic
from sanic.compat import UVLOOP_INSTALLED
from sanic.http.constants import HTTP
parent_dir = Path(__file__).parent.parent
localhost_dir = parent_dir / "certs/localhost"
@pytest.mark.parametrize("version", (3, HTTP.VERSION_3))
@pytest.mark.skipif(
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
reason="In 3.7 w/o uvloop the port is not always released",
)
def test_server_starts_http3(app: Sanic, version, caplog):
ev = Event()
@app.after_server_start
def shutdown(*_):
ev.set()
app.stop()
with caplog.at_level(logging.INFO):
app.run(
version=version,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
assert ev.is_set()
assert (
"sanic.root",
logging.INFO,
"server: sanic, HTTP/3",
) in caplog.record_tuples
@pytest.mark.skipif(
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
reason="In 3.7 w/o uvloop the port is not always released",
)
def test_server_starts_http1_and_http3(app: Sanic, caplog):
@app.after_server_start
def shutdown(*_):
app.stop()
app.prepare(
version=3,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
app.prepare(
version=1,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
with caplog.at_level(logging.INFO):
Sanic.serve()
assert (
"sanic.root",
logging.INFO,
"server: sanic, HTTP/1.1",
) in caplog.record_tuples
assert (
"sanic.root",
logging.INFO,
"server: sanic, HTTP/3",
) in caplog.record_tuples
@pytest.mark.skipif(
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
reason="In 3.7 w/o uvloop the port is not always released",
)
def test_server_starts_http1_and_http3_bad_order(app: Sanic, caplog):
@app.after_server_start
def shutdown(*_):
app.stop()
app.prepare(
version=1,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
message = (
"Serving HTTP/3 instances as a secondary server is not supported. "
"There can only be a single HTTP/3 worker and it must be the first "
"instance prepared."
)
with pytest.raises(RuntimeError, match=message):
app.prepare(
version=3,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)

View File

@@ -0,0 +1,46 @@
from datetime import datetime
from aioquic.tls import CipherSuite, SessionTicket
from sanic.http.http3 import SessionTicketStore
def _generate_ticket(label):
return SessionTicket(
1,
CipherSuite.AES_128_GCM_SHA256,
datetime.now(),
datetime.now(),
label,
label.decode(),
label,
None,
[],
)
def test_session_ticket_store():
store = SessionTicketStore()
assert len(store.tickets) == 0
ticket1 = _generate_ticket(b"foo")
store.add(ticket1)
assert len(store.tickets) == 1
ticket2 = _generate_ticket(b"bar")
store.add(ticket2)
assert len(store.tickets) == 2
assert len(store.tickets) == 2
popped2 = store.pop(ticket2.ticket)
assert len(store.tickets) == 1
assert popped2 is ticket2
popped1 = store.pop(ticket1.ticket)
assert len(store.tickets) == 0
assert popped1 is ticket1

View File

@@ -4,7 +4,6 @@ import re
from collections import Counter from collections import Counter
from inspect import isawaitable from inspect import isawaitable
from os import environ
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@@ -113,19 +112,6 @@ def test_create_server_main_convenience(app, caplog):
) in caplog.record_tuples ) in caplog.record_tuples
def test_create_server_init(app, caplog):
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(return_asyncio_server=True)
server = loop.run_until_complete(asyncio_srv_coro)
message = (
"AsyncioServer.init has been deprecated and will be removed in v22.6. "
"Use Sanic.state.is_started instead."
)
with pytest.warns(DeprecationWarning, match=message):
server.init
def test_app_loop_not_running(app): def test_app_loop_not_running(app):
with pytest.raises(SanicException) as excinfo: with pytest.raises(SanicException) as excinfo:
app.loop app.loop
@@ -385,40 +371,6 @@ def test_get_app_default_ambiguous():
Sanic.get_app() Sanic.get_app()
def test_app_no_registry():
Sanic("no-register", register=False)
with pytest.raises(
SanicException, match='Sanic app name "no-register" not found.'
):
Sanic.get_app("no-register")
def test_app_no_registry_deprecation_message():
with pytest.warns(DeprecationWarning) as records:
Sanic("no-register", register=False)
Sanic("yes-register", register=True)
message = (
"[DEPRECATION v22.6] The register argument is deprecated and will "
"stop working in v22.6. After v22.6 all apps will be added to the "
"Sanic app registry."
)
assert len(records) == 2
for record in records:
assert record.message.args[0] == message
def test_app_no_registry_env():
environ["SANIC_REGISTER"] = "False"
Sanic("no-register")
with pytest.raises(
SanicException, match='Sanic app name "no-register" not found.'
):
Sanic.get_app("no-register")
del environ["SANIC_REGISTER"]
def test_app_set_attribute_warning(app): def test_app_set_attribute_warning(app):
message = ( message = (
"Setting variables on Sanic instances is not allowed. You should " "Setting variables on Sanic instances is not allowed. You should "

View File

@@ -9,10 +9,11 @@ import uvicorn
from sanic import Sanic from sanic import Sanic
from sanic.application.state import Mode from sanic.application.state import Mode
from sanic.asgi import MockTransport from sanic.asgi import MockTransport
from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.exceptions import BadRequest, Forbidden, ServiceUnavailable
from sanic.request import Request from sanic.request import Request
from sanic.response import json, text from sanic.response import json, text
from sanic.server.websockets.connection import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
from sanic.signals import RESERVED_NAMESPACES
@pytest.fixture @pytest.fixture
@@ -221,6 +222,7 @@ def test_listeners_triggered_async(app, caplog):
assert after_server_stop assert after_server_stop
app.state.mode = Mode.DEBUG app.state.mode = Mode.DEBUG
app.state.verbosity = 0
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
server.run() server.run()
@@ -392,7 +394,7 @@ async def test_websocket_accept_with_multiple_subprotocols(
def test_improper_websocket_connection(transport, send, receive): def test_improper_websocket_connection(transport, send, receive):
with pytest.raises(InvalidUsage): with pytest.raises(BadRequest):
transport.get_websocket_connection() transport.get_websocket_connection()
transport.create_websocket_connection(send, receive) transport.create_websocket_connection(send, receive)
@@ -415,7 +417,7 @@ async def test_request_class_custom():
class MyCustomRequest(Request): class MyCustomRequest(Request):
pass pass
app = Sanic(name=__name__, request_class=MyCustomRequest) app = Sanic(name="Test", request_class=MyCustomRequest)
@app.get("/custom") @app.get("/custom")
def custom_request(request): def custom_request(request):
@@ -513,3 +515,44 @@ async def test_request_exception_suppressed_by_middleware(app):
_, response = await app.asgi_client.get("/error-prone") _, response = await app.asgi_client.get("/error-prone")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio
async def test_signals_triggered(app):
@app.get("/test_signals_triggered")
async def _request(request):
return text("test_signals_triggered")
signals_triggered = []
signals_expected = [
# "http.lifecycle.begin",
# "http.lifecycle.read_head",
"http.lifecycle.request",
"http.lifecycle.handle",
"http.routing.before",
"http.routing.after",
"http.lifecycle.response",
# "http.lifecycle.send",
# "http.lifecycle.complete",
]
def signal_handler(signal):
return lambda *a, **kw: signals_triggered.append(signal)
for signal in RESERVED_NAMESPACES["http"]:
app.signal(signal)(signal_handler(signal))
_, response = await app.asgi_client.get("/test_signals_triggered")
assert response.status_code == 200
assert response.text == "test_signals_triggered"
assert signals_triggered == signals_expected
@pytest.mark.asyncio
async def test_asgi_serve_location(app):
@app.get("/")
def _request(request: Request):
return text(request.app.serve_location)
_, response = await app.asgi_client.get("/")
assert response.text == "http://<ASGI>"

View File

@@ -3,12 +3,7 @@ from pytest import raises
from sanic.app import Sanic from sanic.app import Sanic
from sanic.blueprint_group import BlueprintGroup from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import ( from sanic.exceptions import BadRequest, Forbidden, SanicException, ServerError
Forbidden,
InvalidUsage,
SanicException,
ServerError,
)
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, text from sanic.response import HTTPResponse, text
@@ -104,7 +99,7 @@ def test_bp_group(app: Sanic):
@blueprint_1.route("/invalid") @blueprint_1.route("/invalid")
def blueprint_1_error(request: Request): def blueprint_1_error(request: Request):
raise InvalidUsage("Invalid") raise BadRequest("Invalid")
@blueprint_2.route("/") @blueprint_2.route("/")
def blueprint_2_default_route(request): def blueprint_2_default_route(request):
@@ -120,7 +115,7 @@ def test_bp_group(app: Sanic):
blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3") blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3")
@blueprint_group_1.exception(InvalidUsage) @blueprint_group_1.exception(BadRequest)
def handle_group_exception(request, exception): def handle_group_exception(request, exception):
return text("BP1_ERR_OK") return text("BP1_ERR_OK")

View File

@@ -7,12 +7,7 @@ import pytest
from sanic.app import Sanic from sanic.app import Sanic
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.exceptions import ( from sanic.exceptions import BadRequest, NotFound, SanicException, ServerError
InvalidUsage,
NotFound,
SanicException,
ServerError,
)
from sanic.request import Request from sanic.request import Request
from sanic.response import json, text from sanic.response import json, text
@@ -448,7 +443,7 @@ def test_bp_exception_handler(app):
@blueprint.route("/1") @blueprint.route("/1")
def handler_1(request): def handler_1(request):
raise InvalidUsage("OK") raise BadRequest("OK")
@blueprint.route("/2") @blueprint.route("/2")
def handler_2(request): def handler_2(request):

View File

@@ -2,6 +2,7 @@ import json
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple
import pytest import pytest
@@ -10,7 +11,7 @@ from sanic_routing import __version__ as __routing_version__
from sanic import __version__ from sanic import __version__
def capture(command): def capture(command: List[str]):
proc = subprocess.Popen( proc = subprocess.Popen(
command, command,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@@ -18,21 +19,21 @@ def capture(command):
cwd=Path(__file__).parent, cwd=Path(__file__).parent,
) )
try: try:
out, err = proc.communicate(timeout=1) out, err = proc.communicate(timeout=10)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
proc.kill() proc.kill()
out, err = proc.communicate() out, err = proc.communicate()
return out, err, proc.returncode return out, err, proc.returncode
def starting_line(lines): def starting_line(lines: List[str]):
for idx, line in enumerate(lines): for idx, line in enumerate(lines):
if line.strip().startswith(b"Sanic v"): if line.strip().startswith(b"Sanic v"):
return idx return idx
return 0 return 0
def read_app_info(lines): def read_app_info(lines: List[str]):
for line in lines: for line in lines:
if line.startswith(b"{") and line.endswith(b"}"): if line.startswith(b"{") and line.endswith(b"}"):
return json.loads(line) return json.loads(line)
@@ -46,7 +47,7 @@ def read_app_info(lines):
("fake.server.create_app()", None), ("fake.server.create_app()", None),
), ),
) )
def test_server_run(appname, extra): def test_server_run(appname: str, extra: Optional[str]):
command = ["sanic", appname] command = ["sanic", appname]
if extra: if extra:
command.append(extra) command.append(extra)
@@ -119,7 +120,7 @@ def test_error_with_path_as_instance_without_simple_arg():
), ),
), ),
) )
def test_tls_options(cmd): def test_tls_options(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd, "-p=9999", "--debug"] command = ["sanic", "fake.server.app", *cmd, "-p=9999", "--debug"]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
assert exitcode != 1 assert exitcode != 1
@@ -140,15 +141,14 @@ def test_tls_options(cmd):
("--tls-strict-host",), ("--tls-strict-host",),
), ),
) )
def test_tls_wrong_options(cmd): def test_tls_wrong_options(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd, "-p=9999", "--debug"] command = ["sanic", "fake.server.app", *cmd, "-p=9999", "--debug"]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
assert exitcode == 1 assert exitcode == 1
assert not out assert not out
lines = err.decode().split("\n") lines = err.decode().split("\n")
errmsg = lines[6] assert "TLS certificates must be specified by either of:" in lines
assert errmsg == "TLS certificates must be specified by either of:"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -158,7 +158,7 @@ def test_tls_wrong_options(cmd):
("-H", "localhost", "-p", "9999"), ("-H", "localhost", "-p", "9999"),
), ),
) )
def test_host_port_localhost(cmd): def test_host_port_localhost(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd] command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -175,7 +175,7 @@ def test_host_port_localhost(cmd):
("-H", "127.0.0.127", "-p", "9999"), ("-H", "127.0.0.127", "-p", "9999"),
), ),
) )
def test_host_port_ipv4(cmd): def test_host_port_ipv4(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd] command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -192,7 +192,7 @@ def test_host_port_ipv4(cmd):
("-H", "::", "-p", "9999"), ("-H", "::", "-p", "9999"),
), ),
) )
def test_host_port_ipv6_any(cmd): def test_host_port_ipv6_any(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd] command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -209,7 +209,7 @@ def test_host_port_ipv6_any(cmd):
("-H", "::1", "-p", "9999"), ("-H", "::1", "-p", "9999"),
), ),
) )
def test_host_port_ipv6_loopback(cmd): def test_host_port_ipv6_loopback(cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd] command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -230,7 +230,7 @@ def test_host_port_ipv6_loopback(cmd):
(4, ("-w", "4")), (4, ("-w", "4")),
), ),
) )
def test_num_workers(num, cmd): def test_num_workers(num: int, cmd: Tuple[str]):
command = ["sanic", "fake.server.app", *cmd] command = ["sanic", "fake.server.app", *cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -245,7 +245,7 @@ def test_num_workers(num, cmd):
@pytest.mark.parametrize("cmd", ("--debug",)) @pytest.mark.parametrize("cmd", ("--debug",))
def test_debug(cmd): def test_debug(cmd: str):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -259,7 +259,7 @@ def test_debug(cmd):
@pytest.mark.parametrize("cmd", ("--dev", "-d")) @pytest.mark.parametrize("cmd", ("--dev", "-d"))
def test_dev(cmd): def test_dev(cmd: str):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -272,7 +272,7 @@ def test_dev(cmd):
@pytest.mark.parametrize("cmd", ("--auto-reload", "-r")) @pytest.mark.parametrize("cmd", ("--auto-reload", "-r"))
def test_auto_reload(cmd): def test_auto_reload(cmd: str):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -288,7 +288,7 @@ def test_auto_reload(cmd):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cmd,expected", (("--access-log", True), ("--no-access-log", False)) "cmd,expected", (("--access-log", True), ("--no-access-log", False))
) )
def test_access_logs(cmd, expected): def test_access_logs(cmd: str, expected: bool):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
@@ -300,7 +300,7 @@ def test_access_logs(cmd, expected):
@pytest.mark.parametrize("cmd", ("--version", "-v")) @pytest.mark.parametrize("cmd", ("--version", "-v"))
def test_version(cmd): def test_version(cmd: str):
command = ["sanic", cmd] command = ["sanic", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
version_string = f"Sanic {__version__}; Routing {__routing_version__}\n" version_string = f"Sanic {__version__}; Routing {__routing_version__}\n"
@@ -315,7 +315,7 @@ def test_version(cmd):
("--no-noisy-exceptions", False), ("--no-noisy-exceptions", False),
), ),
) )
def test_noisy_exceptions(cmd, expected): def test_noisy_exceptions(cmd: str, expected: bool):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")

View File

@@ -1,4 +1,5 @@
import logging import logging
import os
from contextlib import contextmanager from contextlib import contextmanager
from os import environ from os import environ
@@ -13,6 +14,7 @@ from pytest import MonkeyPatch
from sanic import Sanic from sanic import Sanic
from sanic.config import DEFAULT_CONFIG, Config from sanic.config import DEFAULT_CONFIG, Config
from sanic.constants import LocalCertCreator
from sanic.exceptions import PyFileError from sanic.exceptions import PyFileError
@@ -49,7 +51,7 @@ def test_load_from_object(app: Sanic):
def test_load_from_object_string(app: Sanic): def test_load_from_object_string(app: Sanic):
app.config.load("test_config.ConfigTest") app.config.load("tests.test_config.ConfigTest")
assert "CONFIG_VALUE" in app.config assert "CONFIG_VALUE" in app.config
assert app.config.CONFIG_VALUE == "should be used" assert app.config.CONFIG_VALUE == "should be used"
assert "not_for_config" not in app.config assert "not_for_config" not in app.config
@@ -71,14 +73,14 @@ def test_load_from_object_string_exception(app: Sanic):
def test_auto_env_prefix(): def test_auto_env_prefix():
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(name=__name__) app = Sanic(name="Test")
assert app.config.TEST_ANSWER == 42 assert app.config.TEST_ANSWER == 42
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_auto_bool_env_prefix(): def test_auto_bool_env_prefix():
environ["SANIC_TEST_ANSWER"] = "True" environ["SANIC_TEST_ANSWER"] = "True"
app = Sanic(name=__name__) app = Sanic(name="Test")
assert app.config.TEST_ANSWER is True assert app.config.TEST_ANSWER is True
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
@@ -86,28 +88,28 @@ def test_auto_bool_env_prefix():
@pytest.mark.parametrize("env_prefix", [None, ""]) @pytest.mark.parametrize("env_prefix", [None, ""])
def test_empty_load_env_prefix(env_prefix): def test_empty_load_env_prefix(env_prefix):
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(name=__name__, env_prefix=env_prefix) app = Sanic(name="Test", env_prefix=env_prefix)
assert getattr(app.config, "TEST_ANSWER", None) is None assert getattr(app.config, "TEST_ANSWER", None) is None
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_env_prefix(): def test_env_prefix():
environ["MYAPP_TEST_ANSWER"] = "42" environ["MYAPP_TEST_ANSWER"] = "42"
app = Sanic(name=__name__, env_prefix="MYAPP_") app = Sanic(name="Test", env_prefix="MYAPP_")
assert app.config.TEST_ANSWER == 42 assert app.config.TEST_ANSWER == 42
del environ["MYAPP_TEST_ANSWER"] del environ["MYAPP_TEST_ANSWER"]
def test_env_prefix_float_values(): def test_env_prefix_float_values():
environ["MYAPP_TEST_ROI"] = "2.3" environ["MYAPP_TEST_ROI"] = "2.3"
app = Sanic(name=__name__, env_prefix="MYAPP_") app = Sanic(name="Test", env_prefix="MYAPP_")
assert app.config.TEST_ROI == 2.3 assert app.config.TEST_ROI == 2.3
del environ["MYAPP_TEST_ROI"] del environ["MYAPP_TEST_ROI"]
def test_env_prefix_string_value(): def test_env_prefix_string_value():
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken" environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
app = Sanic(name=__name__, env_prefix="MYAPP_") app = Sanic(name="Test", env_prefix="MYAPP_")
assert app.config.TEST_TOKEN == "somerandomtesttoken" assert app.config.TEST_TOKEN == "somerandomtesttoken"
del environ["MYAPP_TEST_TOKEN"] del environ["MYAPP_TEST_TOKEN"]
@@ -116,7 +118,7 @@ def test_env_w_custom_converter():
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
config = Config(converters=[UltimateAnswer]) config = Config(converters=[UltimateAnswer])
app = Sanic(name=__name__, config=config) app = Sanic(name="Test", config=config)
assert isinstance(app.config.TEST_ANSWER, UltimateAnswer) assert isinstance(app.config.TEST_ANSWER, UltimateAnswer)
assert app.config.TEST_ANSWER.answer == 42 assert app.config.TEST_ANSWER.answer == 42
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
@@ -125,7 +127,7 @@ def test_env_w_custom_converter():
def test_env_lowercase(): def test_env_lowercase():
with pytest.warns(None) as record: with pytest.warns(None) as record:
environ["SANIC_test_answer"] = "42" environ["SANIC_test_answer"] = "42"
app = Sanic(name=__name__) app = Sanic(name="Test")
assert app.config.test_answer == 42 assert app.config.test_answer == 42
assert str(record[0].message) == ( assert str(record[0].message) == (
"[DEPRECATION v22.9] Lowercase environment variables will not be " "[DEPRECATION v22.9] Lowercase environment variables will not be "
@@ -369,15 +371,6 @@ def test_update_from_lowercase_key(app: Sanic):
assert "test_setting_value" not in app.config assert "test_setting_value" not in app.config
def test_deprecation_notice_when_setting_logo(app: Sanic):
message = (
"Setting the config.LOGO is deprecated and will no longer be "
"supported starting in v22.6."
)
with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo"
def test_config_set_methods(app: Sanic, monkeypatch: MonkeyPatch): def test_config_set_methods(app: Sanic, monkeypatch: MonkeyPatch):
post_set = Mock() post_set = Mock()
monkeypatch.setattr(Config, "_post_set", post_set) monkeypatch.setattr(Config, "_post_set", post_set)
@@ -435,3 +428,21 @@ def test_negative_proxy_count(app: Sanic):
) )
with pytest.raises(ValueError, match=message): with pytest.raises(ValueError, match=message):
app.prepare() app.prepare()
@pytest.mark.parametrize(
"passed,expected",
(
("auto", LocalCertCreator.AUTO),
("mkcert", LocalCertCreator.MKCERT),
("trustme", LocalCertCreator.TRUSTME),
("AUTO", LocalCertCreator.AUTO),
("MKCERT", LocalCertCreator.MKCERT),
("TRUSTME", LocalCertCreator.TRUSTME),
),
)
def test_convert_local_cert_creator(passed, expected):
os.environ["SANIC_LOCAL_CERT_CREATOR"] = passed
app = Sanic("Test")
assert app.config.LOCAL_CERT_CREATOR is expected
del os.environ["SANIC_LOCAL_CERT_CREATOR"]

View File

@@ -3,6 +3,7 @@ from http.cookies import SimpleCookie
import pytest import pytest
from sanic import Sanic
from sanic.cookies import Cookie from sanic.cookies import Cookie
from sanic.response import text from sanic.response import text
@@ -221,30 +222,29 @@ def test_cookie_bad_max_age(app, max_age):
assert response.status == 500 assert response.status == 500
@pytest.mark.parametrize( @pytest.mark.parametrize("expires", [timedelta(seconds=60)])
"expires", [datetime.utcnow() + timedelta(seconds=60)] def test_cookie_expires(app: Sanic, expires: timedelta):
) expires_time = datetime.utcnow().replace(microsecond=0) + expires
def test_cookie_expires(app, expires):
expires = expires.replace(microsecond=0)
cookies = {"test": "wait"} cookies = {"test": "wait"}
@app.get("/") @app.get("/")
def handler(request): def handler(request):
response = text("pass") response = text("pass")
response.cookies["test"] = "pass" response.cookies["test"] = "pass"
response.cookies["test"]["expires"] = expires response.cookies["test"]["expires"] = expires_time
return response return response
request, response = app.test_client.get( request, response = app.test_client.get(
"/", cookies=cookies, raw_cookies=True "/", cookies=cookies, raw_cookies=True
) )
cookie_expires = datetime.utcfromtimestamp( cookie_expires = datetime.utcfromtimestamp(
response.raw_cookies["test"].expires response.raw_cookies["test"].expires
).replace(microsecond=0) ).replace(microsecond=0)
assert response.status == 200 assert response.status == 200
assert response.cookies["test"] == "pass" assert response.cookies["test"] == "pass"
assert cookie_expires == expires assert cookie_expires == expires_time
@pytest.mark.parametrize("expires", ["Fri, 21-Dec-2018 15:30:00 GMT"]) @pytest.mark.parametrize("expires", ["Fri, 21-Dec-2018 15:30:00 GMT"])

View File

@@ -17,7 +17,7 @@ class CustomRequest(Request):
def test_custom_request(): def test_custom_request():
app = Sanic(name=__name__, request_class=CustomRequest) app = Sanic(name="Test", request_class=CustomRequest)
@app.route("/post", methods=["POST"]) @app.route("/post", methods=["POST"])
async def post_handler(request): async def post_handler(request):

View File

@@ -6,9 +6,16 @@ from bs4 import BeautifulSoup
from sanic import Sanic from sanic import Sanic
from sanic.exceptions import ( from sanic.exceptions import (
BadRequest,
ContentRangeError,
ExpectationFailed,
Forbidden, Forbidden,
HeaderExpectationFailed,
InvalidUsage, InvalidUsage,
MethodNotAllowed,
MethodNotSupported,
NotFound, NotFound,
RangeNotSatisfiable,
SanicException, SanicException,
ServerError, ServerError,
Unauthorized, Unauthorized,
@@ -77,7 +84,7 @@ def exception_app():
@app.route("/invalid") @app.route("/invalid")
def handler_invalid(request): def handler_invalid(request):
raise InvalidUsage("OK") raise BadRequest("OK")
@app.route("/abort/401") @app.route("/abort/401")
def handler_401_error(request): def handler_401_error(request):
@@ -136,7 +143,7 @@ def test_server_error_exception(exception_app):
def test_invalid_usage_exception(exception_app): def test_invalid_usage_exception(exception_app):
"""Test the built-in InvalidUsage exception works""" """Test the built-in BadRequest exception works"""
request, response = exception_app.test_client.get("/invalid") request, response = exception_app.test_client.get("/invalid")
assert response.status == 400 assert response.status == 400
@@ -252,7 +259,7 @@ def test_custom_exception_default_message(exception_app):
def test_exception_in_ws_logged(caplog): def test_exception_in_ws_logged(caplog):
app = Sanic(__name__) app = Sanic("Test")
@app.websocket("/feed") @app.websocket("/feed")
async def feed(request, ws): async def feed(request, ws):
@@ -272,7 +279,7 @@ def test_exception_in_ws_logged(caplog):
@pytest.mark.parametrize("debug", (True, False)) @pytest.mark.parametrize("debug", (True, False))
def test_contextual_exception_context(debug): def test_contextual_exception_context(debug):
app = Sanic(__name__) app = Sanic("Test")
class TeapotError(SanicException): class TeapotError(SanicException):
status_code = 418 status_code = 418
@@ -307,7 +314,7 @@ def test_contextual_exception_context(debug):
@pytest.mark.parametrize("debug", (True, False)) @pytest.mark.parametrize("debug", (True, False))
def test_contextual_exception_extra(debug): def test_contextual_exception_extra(debug):
app = Sanic(__name__) app = Sanic("Test")
class TeapotError(SanicException): class TeapotError(SanicException):
status_code = 418 status_code = 418
@@ -354,7 +361,7 @@ def test_contextual_exception_extra(debug):
@pytest.mark.parametrize("override", (True, False)) @pytest.mark.parametrize("override", (True, False))
def test_contextual_exception_functional_message(override): def test_contextual_exception_functional_message(override):
app = Sanic(__name__) app = Sanic("Test")
class TeapotError(SanicException): class TeapotError(SanicException):
status_code = 418 status_code = 418
@@ -375,3 +382,10 @@ def test_contextual_exception_functional_message(override):
assert response.status == 418 assert response.status == 418
assert response.json["message"] == error_message assert response.json["message"] == error_message
assert response.json["context"] == {"foo": "bar"} assert response.json["context"] == {"foo": "bar"}
def test_exception_aliases():
assert InvalidUsage is BadRequest
assert MethodNotSupported is MethodNotAllowed
assert ContentRangeError is RangeNotSatisfiable
assert HeaderExpectationFailed is ExpectationFailed

View File

@@ -10,16 +10,10 @@ from bs4 import BeautifulSoup
from pytest import LogCaptureFixture, MonkeyPatch from pytest import LogCaptureFixture, MonkeyPatch
from sanic import Sanic, handlers from sanic import Sanic, handlers
from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError from sanic.exceptions import BadRequest, Forbidden, NotFound, ServerError
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.request import Request from sanic.request import Request
from sanic.response import stream, text from sanic.response import text
async def sample_streaming_fn(response):
await response.write("foo,")
await asyncio.sleep(0.001)
await response.write("bar")
class ErrorWithRequestCtx(ServerError): class ErrorWithRequestCtx(ServerError):
@@ -32,7 +26,7 @@ def exception_handler_app():
@exception_handler_app.route("/1", error_format="html") @exception_handler_app.route("/1", error_format="html")
def handler_1(request): def handler_1(request):
raise InvalidUsage("OK") raise BadRequest("OK")
@exception_handler_app.route("/2", error_format="html") @exception_handler_app.route("/2", error_format="html")
def handler_2(request): def handler_2(request):
@@ -81,10 +75,10 @@ def exception_handler_app():
@exception_handler_app.exception(Forbidden) @exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception): async def async_handler_exception(request, exception):
return stream( response = await request.respond(content_type="text/csv")
sample_streaming_fn, await response.send("foo,")
content_type="text/csv", await asyncio.sleep(0.001)
) await response.send("bar")
@exception_handler_app.middleware @exception_handler_app.middleware
async def some_request_middleware(request): async def some_request_middleware(request):
@@ -183,7 +177,7 @@ def test_exception_handler_lookup(exception_handler_app: Sanic):
class ModuleNotFoundError(ImportError): class ModuleNotFoundError(ImportError):
pass pass
handler = ErrorHandler("auto") handler = ErrorHandler()
handler.add(ImportError, import_error_handler) handler.add(ImportError, import_error_handler)
handler.add(CustomError, custom_error_handler) handler.add(CustomError, custom_error_handler)
handler.add(ServerError, server_error_handler) handler.add(ServerError, server_error_handler)
@@ -261,7 +255,6 @@ def test_exception_handler_response_was_sent(
_, response = app.test_client.get("/1") _, response = app.test_client.get("/1")
assert "some text" in response.text assert "some text" in response.text
# Change to assert warning not in the records in the future version.
message_in_records( message_in_records(
caplog.records, caplog.records,
( (

View File

@@ -1,9 +1,8 @@
import asyncio
import json as stdjson import json as stdjson
from collections import namedtuple from collections import namedtuple
from textwrap import dedent from pathlib import Path
from typing import AnyStr from sys import version_info
import pytest import pytest
@@ -11,54 +10,15 @@ from sanic_testing.reusable import ReusableClient
from sanic import json, text from sanic import json, text
from sanic.app import Sanic from sanic.app import Sanic
from tests.client import RawClient
parent_dir = Path(__file__).parent
localhost_dir = parent_dir / "certs/localhost"
PORT = 1234 PORT = 1234
class RawClient:
CRLF = b"\r\n"
def __init__(self, host: str, port: int):
self.reader = None
self.writer = None
self.host = host
self.port = port
async def connect(self):
self.reader, self.writer = await asyncio.open_connection(
self.host, self.port
)
async def close(self):
self.writer.close()
await self.writer.wait_closed()
async def send(self, message: AnyStr):
if isinstance(message, str):
msg = self._clean(message).encode("utf-8")
else:
msg = message
await self._send(msg)
async def _send(self, message: bytes):
if not self.writer:
raise Exception("No open write stream")
self.writer.write(message)
async def recv(self, nbytes: int = -1) -> bytes:
if not self.reader:
raise Exception("No open read stream")
return await self.reader.read(nbytes)
def _clean(self, message: str) -> str:
return (
dedent(message)
.lstrip("\n")
.replace("\n", self.CRLF.decode("utf-8"))
)
@pytest.fixture @pytest.fixture
def test_app(app: Sanic): def test_app(app: Sanic):
app.config.KEEP_ALIVE_TIMEOUT = 1 app.config.KEEP_ALIVE_TIMEOUT = 1
@@ -115,7 +75,10 @@ def test_full_message(client):
""" """
) )
response = client.recv() response = client.recv()
assert len(response) == 140
# AltSvcCheck touchup removes the Alt-Svc header from the
# response in the Python 3.9+ in this case
assert len(response) == (151 if version_info < (3, 9) else 140)
assert b"200 OK" in response assert b"200 OK" in response

View File

@@ -0,0 +1,66 @@
import sys
from pathlib import Path
import pytest
from sanic.app import Sanic
from sanic.response import empty
from tests.client import RawClient
parent_dir = Path(__file__).parent
localhost_dir = parent_dir / "certs/localhost"
PORT = 12344
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Not supported in 3.7")
def test_http1_response_has_alt_svc():
Sanic._app_registry.clear()
app = Sanic("TestAltSvc")
app.config.TOUCHUP = True
response = b""
@app.get("/")
async def handler(*_):
return empty()
@app.after_server_start
async def do_request(*_):
nonlocal response
app.router.reset()
app.router.finalize()
client = RawClient(app.state.host, app.state.port)
await client.connect()
await client.send(
"""
GET / HTTP/1.1
host: localhost:7777
"""
)
response = await client.recv()
await client.close()
@app.after_server_start
def shutdown(*_):
app.stop()
app.prepare(
version=3,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
port=PORT,
)
app.prepare(
version=1,
port=PORT,
)
Sanic.serve()
assert f'alt-svc: h3=":{PORT}"\r\n'.encode() in response

View File

@@ -0,0 +1,52 @@
from json import loads as sloads
import pytest
try:
from ujson import loads as uloads
NO_UJSON = False
DEFAULT_LOADS = uloads
except ModuleNotFoundError:
NO_UJSON = True
DEFAULT_LOADS = sloads
from sanic import Request, Sanic, json
@pytest.fixture(autouse=True)
def default_back_to_ujson():
yield
Request._loads = DEFAULT_LOADS
def test_change_decoder():
Sanic("Test", loads=sloads)
assert Request._loads == sloads
def test_change_decoder_to_some_custom():
def my_custom_decoder(some_str: str):
dict = sloads(some_str)
dict["some_key"] = "new_value"
return dict
app = Sanic("Test", loads=my_custom_decoder)
assert Request._loads == my_custom_decoder
req_body = {"some_key": "some_value"}
@app.post("/test")
def handler(request):
new_json = request.json
return json(new_json)
req, res = app.test_client.post("/test", json=req_body)
assert sloads(res.body) == {"some_key": "new_value"}
@pytest.mark.skipif(NO_UJSON is True, reason="ujson not installed")
def test_default_decoder():
Sanic("Test")
assert Request._loads == uloads

View File

@@ -136,7 +136,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
async def test_logger(caplog): async def test_logger(caplog):
rand_string = str(uuid.uuid4()) rand_string = str(uuid.uuid4())
app = Sanic(name=__name__) app = Sanic(name="Test")
@app.get("/") @app.get("/")
def log_info(request): def log_info(request):
@@ -163,7 +163,7 @@ def test_logging_modified_root_logger_config():
def test_access_log_client_ip_remote_addr(monkeypatch): def test_access_log_client_ip_remote_addr(monkeypatch):
access = Mock() access = Mock()
monkeypatch.setattr(sanic.http, "access_logger", access) monkeypatch.setattr(sanic.http.http1, "access_logger", access)
app = Sanic("test_logging") app = Sanic("test_logging")
app.config.PROXIES_COUNT = 2 app.config.PROXIES_COUNT = 2
@@ -190,7 +190,7 @@ def test_access_log_client_ip_remote_addr(monkeypatch):
def test_access_log_client_ip_reqip(monkeypatch): def test_access_log_client_ip_reqip(monkeypatch):
access = Mock() access = Mock()
monkeypatch.setattr(sanic.http, "access_logger", access) monkeypatch.setattr(sanic.http.http1, "access_logger", access)
app = Sanic("test_logging") app = Sanic("test_logging")
@@ -209,3 +209,42 @@ def test_access_log_client_ip_reqip(monkeypatch):
"request": f"GET {request.scheme}://{request.host}/", "request": f"GET {request.scheme}://{request.host}/",
}, },
) )
@pytest.mark.parametrize(
"app_verbosity,log_verbosity,exists",
(
(0, 0, True),
(0, 1, False),
(0, 2, False),
(1, 0, True),
(1, 1, True),
(1, 2, False),
(2, 0, True),
(2, 1, True),
(2, 2, True),
),
)
def test_verbosity(app, caplog, app_verbosity, log_verbosity, exists):
rand_string = str(uuid.uuid4())
@app.get("/")
def log_info(request):
logger.info("DEFAULT")
logger.info(rand_string, extra={"verbosity": log_verbosity})
return text("hello")
with caplog.at_level(logging.INFO):
_ = app.test_client.get(
"/", server_kwargs={"verbosity": app_verbosity}
)
record = ("sanic.root", logging.INFO, rand_string)
if exists:
assert record in caplog.record_tuples
else:
assert record not in caplog.record_tuples
if app_verbosity == 0:
assert ("sanic.root", logging.INFO, "DEFAULT") in caplog.record_tuples

View File

@@ -19,41 +19,12 @@ def test_logo_base(app, run_startup):
assert logs[0][2] == BASE_LOGO assert logs[0][2] == BASE_LOGO
def test_logo_false(app, run_startup):
app.config.LOGO = False
logs = run_startup(app)
banner, port = logs[1][2].rsplit(":", 1)
assert logs[0][1] == logging.INFO
assert banner == "Goin' Fast @ http://127.0.0.1"
assert int(port) > 0
def test_logo_true(app, run_startup):
app.config.LOGO = True
logs = run_startup(app)
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == BASE_LOGO
def test_logo_custom(app, run_startup):
app.config.LOGO = "My Custom Logo"
logs = run_startup(app)
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == "My Custom Logo"
def test_motd_with_expected_info(app, run_startup): def test_motd_with_expected_info(app, run_startup):
logs = run_startup(app) logs = run_startup(app)
assert logs[1][2] == f"Sanic v{__version__}" assert logs[1][2] == f"Sanic v{__version__}"
assert logs[3][2] == "mode: debug, single worker" assert logs[3][2] == "mode: debug, single worker"
assert logs[4][2] == "server: sanic" assert logs[4][2] == "server: sanic, HTTP/1.1"
assert logs[5][2] == f"python: {platform.python_version()}" assert logs[5][2] == f"python: {platform.python_version()}"
assert logs[6][2] == f"platform: {platform.platform()}" assert logs[6][2] == f"platform: {platform.platform()}"

View File

@@ -14,7 +14,7 @@ from sanic.touchup.schemes.ode import OptionalDispatchEvent
try: try:
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
except ImportError: except ImportError:
from asyncmock import AsyncMock # type: ignore from tests.asyncmock import AsyncMock # type: ignore
@pytest.fixture @pytest.fixture

View File

@@ -4,7 +4,7 @@ from uuid import UUID, uuid4
import pytest import pytest
from sanic import Sanic, response from sanic import Sanic, response
from sanic.exceptions import BadURL from sanic.exceptions import BadURL, SanicException
from sanic.request import Request, uuid from sanic.request import Request, uuid
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
@@ -191,3 +191,55 @@ def test_bad_url_parse():
Mock(), Mock(),
Mock(), Mock(),
) )
def test_request_scope_raises_exception_when_no_asgi():
app = Sanic("no_asgi")
@app.get("/")
async def get(request):
return request.scope
request, response = app.test_client.get("/")
assert response.status == 500
with pytest.raises(NotImplementedError):
_ = request.scope
@pytest.mark.asyncio
async def test_request_scope_is_not_none_when_running_in_asgi(app):
@app.get("/")
async def get(request):
return response.empty()
request, _ = await app.asgi_client.get("/")
assert request.scope is not None
assert request.scope["method"].lower() == "get"
assert request.scope["path"].lower() == "/"
def test_cannot_get_request_outside_of_cycle():
with pytest.raises(SanicException, match="No current request"):
Request.get_current()
def test_get_current_request(app):
@app.get("/")
async def get(request):
return response.json({"same": request is Request.get_current()})
_, resp = app.test_client.get("/")
assert resp.json["same"]
def test_request_stream_id(app):
@app.get("/")
async def get(request):
try:
request.stream_id
except Exception as e:
return response.text(str(e))
_, resp = app.test_client.get("/")
assert resp.text == "Stream ID is only a property of a HTTP/3 request"

View File

@@ -3,7 +3,7 @@ import contextlib
import pytest import pytest
from sanic.response import stream, text from sanic.response import text
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -43,18 +43,16 @@ async def test_stream_request_cancel_when_conn_lost(app):
async def post(request, id): async def post(request, id):
assert isinstance(request.stream, asyncio.Queue) assert isinstance(request.stream, asyncio.Queue)
async def streaming(response): response = await request.respond()
while True:
body = await request.stream.get()
if body is None:
break
await response.write(body.decode("utf-8"))
await asyncio.sleep(1.0) await asyncio.sleep(1.0)
# at this point client is already disconnected # at this point client is already disconnected
app.ctx.still_serving_cancelled_request = True app.ctx.still_serving_cancelled_request = True
while True:
return stream(streaming) body = await request.stream.get()
if body is None:
break
await response.send(body.decode("utf-8"))
# schedule client call # schedule client call
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@@ -552,7 +552,7 @@ def test_streaming_new_api(app):
def test_streaming_echo(): def test_streaming_echo():
"""2-way streaming chat between server and client.""" """2-way streaming chat between server and client."""
app = Sanic(name=__name__) app = Sanic(name="Test")
@app.post("/echo", stream=True) @app.post("/echo", stream=True)
async def handler(request): async def handler(request):

View File

@@ -1016,6 +1016,72 @@ async def test_post_form_urlencoded_asgi(app):
assert request.form.get("test") == "OK" # For request.parsed_form assert request.form.get("test") == "OK" # For request.parsed_form
def test_post_form_urlencoded_keep_blanks(app):
@app.route("/", methods=["POST"])
async def handler(request):
request.get_form(keep_blank_values=True)
return text("OK")
payload = "test="
headers = {"content-type": "application/x-www-form-urlencoded"}
request, response = app.test_client.post(
"/", data=payload, headers=headers
)
assert request.form.get("test") == ""
assert request.form.get("test") == "" # For request.parsed_form
@pytest.mark.asyncio
async def test_post_form_urlencoded_keep_blanks_asgi(app):
@app.route("/", methods=["POST"])
async def handler(request):
request.get_form(keep_blank_values=True)
return text("OK")
payload = "test="
headers = {"content-type": "application/x-www-form-urlencoded"}
request, response = await app.asgi_client.post(
"/", data=payload, headers=headers
)
assert request.form.get("test") == ""
assert request.form.get("test") == "" # For request.parsed_form
def test_post_form_urlencoded_drop_blanks(app):
@app.route("/", methods=["POST"])
async def handler(request):
return text("OK")
payload = "test="
headers = {"content-type": "application/x-www-form-urlencoded"}
request, response = app.test_client.post(
"/", data=payload, headers=headers
)
assert "test" not in request.form.keys()
@pytest.mark.asyncio
async def test_post_form_urlencoded_drop_blanks_asgi(app):
@app.route("/", methods=["POST"])
async def handler(request):
return text("OK")
payload = "test="
headers = {"content-type": "application/x-www-form-urlencoded"}
request, response = await app.asgi_client.post(
"/", data=payload, headers=headers
)
assert "test" not in request.form.keys()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"payload", "payload",
[ [
@@ -1984,7 +2050,7 @@ async def test_request_form_invalid_content_type_asgi(app):
def test_endpoint_basic(): def test_endpoint_basic():
app = Sanic(name=__name__) app = Sanic(name="Test")
@app.route("/") @app.route("/")
def my_unique_handler(request): def my_unique_handler(request):
@@ -1992,12 +2058,12 @@ def test_endpoint_basic():
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert request.endpoint == "test_requests.my_unique_handler" assert request.endpoint == "Test.my_unique_handler"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_endpoint_basic_asgi(): async def test_endpoint_basic_asgi():
app = Sanic(name=__name__) app = Sanic(name="Test")
@app.route("/") @app.route("/")
def my_unique_handler(request): def my_unique_handler(request):
@@ -2005,7 +2071,7 @@ async def test_endpoint_basic_asgi():
request, response = await app.asgi_client.get("/") request, response = await app.asgi_client.get("/")
assert request.endpoint == "test_requests.my_unique_handler" assert request.endpoint == "Test.my_unique_handler"
def test_endpoint_named_app(): def test_endpoint_named_app():

View File

@@ -3,10 +3,13 @@ import inspect
import os import os
from collections import namedtuple from collections import namedtuple
from datetime import datetime
from email.utils import formatdate
from logging import ERROR, LogRecord from logging import ERROR, LogRecord
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path
from random import choice from random import choice
from typing import Callable, List from typing import Callable, List, Union
from urllib.parse import unquote from urllib.parse import unquote
import pytest import pytest
@@ -24,7 +27,6 @@ from sanic.response import (
file_stream, file_stream,
json, json,
raw, raw,
stream,
text, text,
) )
@@ -46,10 +48,13 @@ def test_response_body_not_a_string(app):
assert b"Internal Server Error" in response.body assert b"Internal Server Error" in response.body
async def sample_streaming_fn(response): async def sample_streaming_fn(request, response=None):
await response.write("foo,") if not response:
response = await request.respond(content_type="text/csv")
await response.send("foo,")
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
await response.write("bar") await response.send("bar")
await response.eof()
def test_method_not_allowed(): def test_method_not_allowed():
@@ -98,11 +103,12 @@ def test_response_header(app):
return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"}) return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"})
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert dict(response.headers) == { for key, value in {
"connection": "keep-alive", "connection": "keep-alive",
"content-length": "11", "content-length": "11",
"content-type": "application/json", "content-type": "application/json",
} }.items():
assert response.headers[key] == value
def test_response_content_length(app): def test_response_content_length(app):
@@ -213,10 +219,7 @@ def test_no_content(json_app):
def streaming_app(app): def streaming_app(app):
@app.route("/") @app.route("/")
async def test(request: Request): async def test(request: Request):
return stream( await sample_streaming_fn(request)
sample_streaming_fn,
content_type="text/csv",
)
return app return app
@@ -225,11 +228,11 @@ def streaming_app(app):
def non_chunked_streaming_app(app): def non_chunked_streaming_app(app):
@app.route("/") @app.route("/")
async def test(request: Request): async def test(request: Request):
return stream( response = await request.respond(
sample_streaming_fn,
headers={"Content-Length": "7"}, headers={"Content-Length": "7"},
content_type="text/csv", content_type="text/csv",
) )
await sample_streaming_fn(request, response)
return app return app
@@ -279,18 +282,6 @@ def test_non_chunked_streaming_returns_correct_content(
assert response.text == "foo,bar" assert response.text == "foo,bar"
def test_stream_response_with_cookies_legacy(app):
@app.route("/")
async def test(request: Request):
response = stream(sample_streaming_fn, content_type="text/csv")
response.cookies["test"] = "modified"
response.cookies["test"] = "pass"
return response
request, response = app.test_client.get("/")
assert response.cookies["test"] == "pass"
def test_stream_response_with_cookies(app): def test_stream_response_with_cookies(app):
@app.route("/") @app.route("/")
async def test(request: Request): async def test(request: Request):
@@ -313,7 +304,7 @@ def test_stream_response_with_cookies(app):
def test_stream_response_without_cookies(app): def test_stream_response_without_cookies(app):
@app.route("/") @app.route("/")
async def test(request: Request): async def test(request: Request):
return stream(sample_streaming_fn, content_type="text/csv") await sample_streaming_fn(request)
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert response.cookies == {} assert response.cookies == {}
@@ -328,12 +319,27 @@ def static_file_directory():
return static_directory return static_directory
def get_file_content(static_file_directory, file_name): def path_str_to_path_obj(static_file_directory: Union[Path, str]):
if isinstance(static_file_directory, str):
static_file_directory = Path(static_file_directory)
return static_file_directory
def get_file_content(static_file_directory: Union[Path, str], file_name: str):
"""The content of the static file to check""" """The content of the static file to check"""
with open(os.path.join(static_file_directory, file_name), "rb") as file: static_file_directory = path_str_to_path_obj(static_file_directory)
with open(static_file_directory / file_name, "rb") as file:
return file.read() return file.read()
def get_file_last_modified_timestamp(
static_file_directory: Union[Path, str], file_name: str
):
"""The content of the static file to check"""
static_file_directory = path_str_to_path_obj(static_file_directory)
return (static_file_directory / file_name).stat().st_mtime
@pytest.mark.parametrize( @pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"] "file_name", ["test.file", "decode me.txt", "python.png"]
) )
@@ -711,3 +717,84 @@ def send_response_after_eof_should_fail(
assert "foo, " in response.text assert "foo, " in response.text
assert message_in_records(caplog.records, error_msg1) assert message_in_records(caplog.records, error_msg1)
assert message_in_records(caplog.records, error_msg2) assert message_in_records(caplog.records, error_msg2)
@pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_response_headers(
app: Sanic, file_name: str, static_file_directory: str
):
test_last_modified = datetime.now()
test_max_age = 10
test_expires = test_last_modified.timestamp() + test_max_age
@app.route("/files/cached/<filename>", methods=["GET"])
def file_route_cache(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
return file(
file_path, max_age=test_max_age, last_modified=test_last_modified
)
@app.route(
"/files/cached_default_last_modified/<filename>", methods=["GET"]
)
def file_route_cache_default_last_modified(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
return file(file_path, max_age=test_max_age)
@app.route("/files/no_cache/<filename>", methods=["GET"])
def file_route_no_cache(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
return file(file_path)
@app.route("/files/no_store/<filename>", methods=["GET"])
def file_route_no_store(request, filename):
file_path = (Path(static_file_directory) / file_name).absolute()
return file(file_path, no_store=True)
_, response = app.test_client.get(f"/files/cached/{file_name}")
assert response.body == get_file_content(static_file_directory, file_name)
headers = response.headers
assert (
"cache-control" in headers
and f"max-age={test_max_age}" in headers.get("cache-control")
and f"public" in headers.get("cache-control")
)
assert (
"expires" in headers
and headers.get("expires")[:-6]
== formatdate(test_expires, usegmt=True)[:-6]
# [:-6] to allow at most 1 min difference
# It's minimal for cases like:
# Thu, 26 May 2022 05:36:49 GMT
# AND
# Thu, 26 May 2022 05:36:50 GMT
)
assert "last-modified" in headers and headers.get(
"last-modified"
) == formatdate(test_last_modified.timestamp(), usegmt=True)
_, response = app.test_client.get(
f"/files/cached_default_last_modified/{file_name}"
)
file_last_modified = get_file_last_modified_timestamp(
static_file_directory, file_name
)
headers = response.headers
assert "last-modified" in headers and headers.get(
"last-modified"
) == formatdate(file_last_modified, usegmt=True)
_, response = app.test_client.get(f"/files/no_cache/{file_name}")
headers = response.headers
assert "cache-control" in headers and f"no-cache" == headers.get(
"cache-control"
)
_, response = app.test_client.get(f"/files/no_store/{file_name}")
headers = response.headers
assert "cache-control" in headers and f"no-store" == headers.get(
"cache-control"
)

View File

@@ -8,7 +8,7 @@ import pytest
from sanic_testing.testing import HOST, PORT from sanic_testing.testing import HOST, PORT
from sanic.exceptions import InvalidUsage, SanicException from sanic.exceptions import BadRequest, SanicException
AVAILABLE_LISTENERS = [ AVAILABLE_LISTENERS = [
@@ -137,7 +137,7 @@ async def test_trigger_before_events_create_server_missing_event(app):
class MySanicDb: class MySanicDb:
pass pass
with pytest.raises(InvalidUsage): with pytest.raises(BadRequest):
@app.listener @app.listener
async def init_db(app, loop): async def init_db(app, loop):

View File

@@ -11,7 +11,7 @@ import pytest
from sanic_testing.testing import HOST, PORT from sanic_testing.testing import HOST, PORT
from sanic.compat import ctrlc_workaround_for_windows from sanic.compat import ctrlc_workaround_for_windows
from sanic.exceptions import InvalidUsage from sanic.exceptions import BadRequest
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
@@ -122,6 +122,6 @@ def test_signals_with_invalid_invocation(app):
return HTTPResponse() return HTTPResponse()
with pytest.raises( with pytest.raises(
InvalidUsage, match="Invalid event registration: Missing event name" BadRequest, match="Invalid event registration: Missing event name"
): ):
app.listener(stop) app.listener(stop)

Some files were not shown because too many files have changed in this diff Show More