Add Request contextvars (#2475)

* Add Request contextvars

* Add missing contextvar setter

* Move location of context setter
This commit is contained in:
Adam Hopkins 2022-06-16 22:57:02 +03:00 committed by GitHub
parent a744041e38
commit ce926a34f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 2 deletions

View File

@ -265,6 +265,7 @@ class Http(metaclass=TouchUpMeta):
transport=self.protocol.transport, transport=self.protocol.transport,
app=self.protocol.app, app=self.protocol.app,
) )
self.protocol.request_class._current.set(request)
await self.dispatch( await self.dispatch(
"http.lifecycle.request", "http.lifecycle.request",
inline=True, inline=True,

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -35,7 +36,7 @@ from httptools.parser.errors import HttpParserInvalidURLError # type: ignore
from sanic.compat import CancelledErrors, Header from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import BadRequest, BadURL, ServerError from sanic.exceptions import BadRequest, BadURL, SanicException, ServerError
from sanic.headers import ( from sanic.headers import (
AcceptContainer, AcceptContainer,
Options, Options,
@ -82,6 +83,8 @@ class Request:
Properties of an HTTP request such as URL, headers, etc. Properties of an HTTP request such as URL, headers, etc.
""" """
_current: ContextVar[Request] = ContextVar("request")
__slots__ = ( __slots__ = (
"__weakref__", "__weakref__",
"_cookies", "_cookies",
@ -174,6 +177,13 @@ class Request:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
return f"<{class_name}: {self.method} {self.path}>" return f"<{class_name}: {self.method} {self.path}>"
@classmethod
def get_current(cls) -> Request:
request = cls._current.get(None)
if not request:
raise SanicException("No current request")
return request
@classmethod @classmethod
def generate_id(*_): def generate_id(*_):
return uuid.uuid4() return uuid.uuid4()

View File

@ -4,7 +4,7 @@ from uuid import UUID, uuid4
import pytest import pytest
from sanic import Sanic, response from sanic import Sanic, response
from sanic.exceptions import BadURL from sanic.exceptions import BadURL, SanicException
from sanic.request import Request, uuid from sanic.request import Request, uuid
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
@ -217,3 +217,17 @@ async def test_request_scope_is_not_none_when_running_in_asgi(app):
assert request.scope is not None assert request.scope is not None
assert request.scope["method"].lower() == "get" assert request.scope["method"].lower() == "get"
assert request.scope["path"].lower() == "/" assert request.scope["path"].lower() == "/"
def test_cannot_get_request_outside_of_cycle():
with pytest.raises(SanicException, match="No current request"):
Request.get_current()
def test_get_current_request(app):
@app.get("/")
async def get(request):
return response.json({"same": request is Request.get_current()})
_, resp = app.test_client.get("/")
assert resp.json["same"]