RFC/1684 Context objects (#2063)

* Initial setup

* connection context

* Add tests

* move ctx to conn_info

* Move __setattr__ for __fake_slots__ check into base calss
This commit is contained in:
Adam Hopkins 2021-03-17 20:55:52 +02:00 committed by GitHub
parent 01f238de79
commit 8a2ea626c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 226 additions and 47 deletions

View File

@ -17,6 +17,7 @@ from inspect import isawaitable
from socket import socket from socket import socket
from ssl import Purpose, SSLContext, create_default_context from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc from traceback import format_exc
from types import SimpleNamespace
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
@ -76,6 +77,44 @@ class Sanic(BaseSanic):
The main application instance The main application instance
""" """
__fake_slots__ = (
"_app_registry",
"_asgi_client",
"_blueprint_order",
"_future_routes",
"_future_statics",
"_future_middleware",
"_future_listeners",
"_future_exceptions",
"_future_signals",
"_test_client",
"_test_manager",
"asgi",
"blueprints",
"config",
"configure_logging",
"ctx",
"debug",
"error_handler",
"go_fast",
"is_running",
"is_stopping",
"listeners",
"name",
"named_request_middleware",
"named_response_middleware",
"request_class",
"request_middleware",
"response_middleware",
"router",
"signal_router",
"sock",
"strict_slashes",
"test_mode",
"websocket_enabled",
"websocket_tasks",
)
_app_registry: Dict[str, "Sanic"] = {} _app_registry: Dict[str, "Sanic"] = {}
test_mode = False test_mode = False
@ -104,31 +143,33 @@ class Sanic(BaseSanic):
if configure_logging: if configure_logging:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
self.name = name self._asgi_client = None
self.asgi = False
self.router = router or Router()
self.signal_router = signal_router or SignalRouter()
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler()
self.config = Config(load_env=load_env)
self.request_middleware: Deque[MiddlewareType] = deque()
self.response_middleware: Deque[MiddlewareType] = deque()
self.blueprints: Dict[str, Blueprint] = {}
self._blueprint_order: List[Blueprint] = [] self._blueprint_order: List[Blueprint] = []
self._test_client = None
self._test_manager = None
self.asgi = False
self.blueprints: Dict[str, Blueprint] = {}
self.config = Config(load_env=load_env)
self.configure_logging = configure_logging self.configure_logging = configure_logging
self.ctx = SimpleNamespace()
self.debug = None self.debug = None
self.sock = None self.error_handler = error_handler or ErrorHandler()
self.strict_slashes = strict_slashes
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
self.is_stopping = False
self.is_running = False self.is_running = False
self.websocket_enabled = False self.is_stopping = False
self.websocket_tasks: Set[Future] = set() self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
self.name = name
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
self._test_manager = None self.request_class = request_class
self._test_client = None self.request_middleware: Deque[MiddlewareType] = deque()
self._asgi_client = None self.response_middleware: Deque[MiddlewareType] = deque()
self.router = router or Router()
self.signal_router = signal_router or SignalRouter()
self.sock = None
self.strict_slashes = strict_slashes
self.websocket_enabled = False
self.websocket_tasks: Set[Future] = set()
# Register alternative method names # Register alternative method names
self.go_fast = self.run self.go_fast = self.run

View File

@ -1,3 +1,6 @@
from typing import Any, Tuple
from warnings import warn
from sanic.mixins.exceptions import ExceptionMixin from sanic.mixins.exceptions import ExceptionMixin
from sanic.mixins.listeners import ListenerMixin from sanic.mixins.listeners import ListenerMixin
from sanic.mixins.middleware import MiddlewareMixin from sanic.mixins.middleware import MiddlewareMixin
@ -35,8 +38,23 @@ class BaseSanic(
SignalMixin, SignalMixin,
metaclass=Base, metaclass=Base,
): ):
__fake_slots__: Tuple[str, ...]
def __str__(self) -> str: def __str__(self) -> str:
return f"<{self.__class__.__name__} {self.name}>" return f"<{self.__class__.__name__} {self.name}>"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'{self.__class__.__name__}(name="{self.name}")' return f'{self.__class__.__name__}(name="{self.name}")'
def __setattr__(self, name: str, value: Any) -> None:
# This is a temporary compat layer so we can raise a warning until
# setting attributes on the app instance can be removed and deprecated
# with a proper implementation of __slots__
if name not in self.__fake_slots__:
warn(
f"Setting variables on {self.__class__.__name__} instances is "
"deprecated and will be removed in version 21.9. You should "
f"change your {self.__class__.__name__} instance to use "
f"instance.ctx.{name} instead."
)
super().__setattr__(name, value)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union
from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.exceptions import NotFound # type: ignore
@ -42,6 +43,28 @@ class Blueprint(BaseSanic):
training */* training */*
""" """
__fake_slots__ = (
"_apps",
"_future_routes",
"_future_statics",
"_future_middleware",
"_future_listeners",
"_future_exceptions",
"_future_signals",
"ctx",
"exceptions",
"host",
"listeners",
"middlewares",
"name",
"routes",
"statics",
"strict_slashes",
"url_prefix",
"version",
"websocket_routes",
)
def __init__( def __init__(
self, self,
name: str, name: str,
@ -50,19 +73,20 @@ class Blueprint(BaseSanic):
version: Optional[int] = None, version: Optional[int] = None,
strict_slashes: Optional[bool] = None, strict_slashes: Optional[bool] = None,
): ):
self._apps: Set[Sanic] = set()
self.name = name
self.url_prefix = url_prefix
self.host = host
self.routes: List[Route] = [] self._apps: Set[Sanic] = set()
self.websocket_routes: List[Route] = [] self.ctx = SimpleNamespace()
self.exceptions: List[RouteHandler] = [] self.exceptions: List[RouteHandler] = []
self.host = host
self.listeners: Dict[str, List[ListenerType]] = {} self.listeners: Dict[str, List[ListenerType]] = {}
self.middlewares: List[MiddlewareType] = [] self.middlewares: List[MiddlewareType] = []
self.name = name
self.routes: List[Route] = []
self.statics: List[RouteHandler] = [] self.statics: List[RouteHandler] = []
self.version = version
self.strict_slashes = strict_slashes self.strict_slashes = strict_slashes
self.url_prefix = url_prefix
self.version = version
self.websocket_routes: List[Route] = []
def __repr__(self) -> str: def __repr__(self) -> str:
args = ", ".join( args = ", ".join(

View File

@ -82,6 +82,7 @@ class Request:
"_ip", "_ip",
"_parsed_url", "_parsed_url",
"_port", "_port",
"_protocol",
"_remote_addr", "_remote_addr",
"_socket", "_socket",
"_match_info", "_match_info",
@ -153,6 +154,7 @@ class Request:
self.stream: Optional[Http] = None self.stream: Optional[Http] = None
self.endpoint: Optional[str] = None self.endpoint: Optional[str] = None
self.route: Optional[Route] = None self.route: Optional[Route] = None
self._protocol = None
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -205,6 +207,12 @@ class Request:
if not self.body: if not self.body:
self.body = b"".join([data async for data in self.stream]) self.body = b"".join([data async for data in self.stream])
@property
def protocol(self):
if not self._protocol:
self._protocol = self.transport.get_protocol()
return self._protocol
@property @property
def raw_headers(self): def raw_headers(self):
_, headers = self.head.split(b"\r\n", 1) _, headers = self.head.split(b"\r\n", 1)

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from ssl import SSLContext from ssl import SSLContext
from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -62,24 +63,28 @@ class ConnInfo:
""" """
__slots__ = ( __slots__ = (
"sockname",
"peername",
"server",
"server_port",
"client",
"client_port", "client_port",
"client",
"ctx",
"peername",
"server_port",
"server",
"sockname",
"ssl", "ssl",
) )
def __init__(self, transport: TransportProtocol, unix=None): def __init__(self, transport: TransportProtocol, unix=None):
self.ssl: bool = bool(transport.get_extra_info("sslcontext")) self.ctx = SimpleNamespace()
self.peername = None
self.server = self.client = "" self.server = self.client = ""
self.server_port = self.client_port = 0 self.server_port = self.client_port = 0
self.peername = None
self.sockname = addr = transport.get_extra_info("sockname") self.sockname = addr = transport.get_extra_info("sockname")
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
if isinstance(addr, str): # UNIX socket if isinstance(addr, str): # UNIX socket
self.server = unix or addr self.server = unix or addr
return return
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple): if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
@ -88,6 +93,7 @@ class ConnInfo:
if addr[1] != (443 if self.ssl else 80): if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}" self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername") self.peername = addr = transport.get_extra_info("peername")
if isinstance(addr, tuple): if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.client_port = addr[1] self.client_port = addr[1]
@ -107,6 +113,7 @@ class HttpProtocol(asyncio.Protocol):
"connections", "connections",
"signal", "signal",
"conn_info", "conn_info",
"ctx",
# request params # request params
"request", "request",
# request config # request config

View File

@ -386,3 +386,22 @@ def test_app_no_registry_env():
): ):
Sanic.get_app("no-register") Sanic.get_app("no-register")
del environ["SANIC_REGISTER"] del environ["SANIC_REGISTER"]
def test_app_set_attribute_warning(app):
with pytest.warns(UserWarning) as record:
app.foo = 1
assert len(record) == 1
assert record[0].message.args[0] == (
"Setting variables on Sanic instances is deprecated "
"and will be removed in version 21.9. You should change your "
"Sanic instance to use instance.ctx.foo instead."
)
def test_app_set_context(app):
app.ctx.foo = 1
retrieved = Sanic.get_app(app.name)
assert retrieved.ctx.foo == 1

View File

@ -1024,3 +1024,16 @@ def test_blueprint_registered_multiple_apps():
for app in (app1, app2): for app in (app1, app2):
_, response = app.test_client.get("/") _, response = app.test_client.get("/")
assert response.text == f"{app.name}.bp.handler" assert response.text == f"{app.name}.bp.handler"
def test_bp_set_attribute_warning():
bp = Blueprint("bp")
with pytest.warns(UserWarning) as record:
bp.foo = 1
assert len(record) == 1
assert record[0].message.args[0] == (
"Setting variables on Blueprint instances is deprecated "
"and will be removed in version 21.9. You should change your "
"Blueprint instance to use instance.ctx.foo instead."
)

View File

@ -19,10 +19,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
from httpcore._async.base import ConnectionState
from httpcore._async.connection import AsyncHTTPConnection
from httpcore._types import Origin
class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool): class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool):
last_reused_connection = None last_reused_connection = None
@ -185,10 +181,12 @@ class ReuseableSanicTestClient(SanicTestClient):
keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse")
keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") keep_alive_app_client_timeout = Sanic("test_ka_client_timeout")
keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") keep_alive_app_server_timeout = Sanic("test_ka_server_timeout")
keep_alive_app_context = Sanic("keep_alive_app_context")
keep_alive_timeout_app_reuse.config.update(CONFIG_FOR_TESTS) keep_alive_timeout_app_reuse.config.update(CONFIG_FOR_TESTS)
keep_alive_app_client_timeout.config.update(CONFIG_FOR_TESTS) keep_alive_app_client_timeout.config.update(CONFIG_FOR_TESTS)
keep_alive_app_server_timeout.config.update(CONFIG_FOR_TESTS) keep_alive_app_server_timeout.config.update(CONFIG_FOR_TESTS)
keep_alive_app_context.config.update(CONFIG_FOR_TESTS)
@keep_alive_timeout_app_reuse.route("/1") @keep_alive_timeout_app_reuse.route("/1")
@ -206,6 +204,17 @@ async def handler3(request):
return text("OK") return text("OK")
@keep_alive_app_context.post("/ctx")
def set_ctx(request):
request.conn_info.ctx.foo = "hello"
return text("OK")
@keep_alive_app_context.get("/ctx")
def get_ctx(request):
return text(request.conn_info.ctx.foo)
@pytest.mark.skipif( @pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS, bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client", reason="Not testable with current client",
@ -243,14 +252,14 @@ def test_keep_alive_client_timeout():
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop)
headers = {"Connection": "keep-alive"} headers = {"Connection": "keep-alive"}
request, response = client.get( _, response = client.get("/1", headers=headers, request_keepalive=1)
"/1", headers=headers, request_keepalive=1
)
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
loop.run_until_complete(aio_sleep(2)) loop.run_until_complete(aio_sleep(2))
exception = None _, response = client.get("/1", request_keepalive=1)
request, response = client.get("/1", request_keepalive=1)
assert ReusableSanicConnectionPool.last_reused_connection is None assert ReusableSanicConnectionPool.last_reused_connection is None
finally: finally:
client.kill_server() client.kill_server()
@ -270,14 +279,38 @@ def test_keep_alive_server_timeout():
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop)
headers = {"Connection": "keep-alive"} headers = {"Connection": "keep-alive"}
request, response = client.get( _, response = client.get("/1", headers=headers, request_keepalive=60)
"/1", headers=headers, request_keepalive=60
)
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
loop.run_until_complete(aio_sleep(3)) loop.run_until_complete(aio_sleep(3))
exception = None _, response = client.get("/1", request_keepalive=60)
request, response = client.get("/1", request_keepalive=60)
assert ReusableSanicConnectionPool.last_reused_connection is None assert ReusableSanicConnectionPool.last_reused_connection is None
finally: finally:
client.kill_server() client.kill_server()
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_connection_context():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_context, loop)
headers = {"Connection": "keep-alive"}
request1, _ = client.post("/ctx", headers=headers)
loop.run_until_complete(aio_sleep(1))
request2, response = client.get("/ctx")
assert response.text == "hello"
assert id(request1.conn_info.ctx) == id(request2.conn_info.ctx)
assert (
request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello"
)
finally:
client.kill_server()

View File

@ -5,6 +5,7 @@ import pytest
from sanic import Sanic, response from sanic import Sanic, response
from sanic.request import Request, uuid from sanic.request import Request, uuid
from sanic.server import HttpProtocol
def test_no_request_id_not_called(monkeypatch): def test_no_request_id_not_called(monkeypatch):
@ -83,3 +84,18 @@ def test_route_assigned_to_request(app):
request, _ = app.test_client.get("/") request, _ = app.test_client.get("/")
assert request.route is list(app.router.routes.values())[0] assert request.route is list(app.router.routes.values())[0]
def test_protocol_attribute(app):
retrieved = None
@app.get("/")
async def get(request):
nonlocal retrieved
retrieved = request.protocol
return response.empty()
headers = {"Connection": "keep-alive"}
_ = app.test_client.get("/", headers=headers)
assert isinstance(retrieved, HttpProtocol)