From ce926a34f29096bc0de1dc99cf0aaa0ec6da2b5b Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Thu, 16 Jun 2022 22:57:02 +0300 Subject: [PATCH] Add Request contextvars (#2475) * Add Request contextvars * Add missing contextvar setter * Move location of context setter --- sanic/http.py | 1 + sanic/request.py | 12 +++++++++++- tests/test_request.py | 16 +++++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sanic/http.py b/sanic/http.py index 330732b2..b63e243d 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -265,6 +265,7 @@ class Http(metaclass=TouchUpMeta): transport=self.protocol.transport, app=self.protocol.app, ) + self.protocol.request_class._current.set(request) await self.dispatch( "http.lifecycle.request", inline=True, diff --git a/sanic/request.py b/sanic/request.py index 5405de0b..f55283c3 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -35,7 +36,7 @@ from httptools.parser.errors import HttpParserInvalidURLError # type: ignore from sanic.compat import CancelledErrors, Header from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE -from sanic.exceptions import BadRequest, BadURL, ServerError +from sanic.exceptions import BadRequest, BadURL, SanicException, ServerError from sanic.headers import ( AcceptContainer, Options, @@ -82,6 +83,8 @@ class Request: Properties of an HTTP request such as URL, headers, etc. """ + _current: ContextVar[Request] = ContextVar("request") + __slots__ = ( "__weakref__", "_cookies", @@ -174,6 +177,13 @@ class Request: class_name = self.__class__.__name__ return f"<{class_name}: {self.method} {self.path}>" + @classmethod + def get_current(cls) -> Request: + request = cls._current.get(None) + if not request: + raise SanicException("No current request") + return request + @classmethod def generate_id(*_): return uuid.uuid4() diff --git a/tests/test_request.py b/tests/test_request.py index 83e2f8e6..cb68325f 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -4,7 +4,7 @@ from uuid import UUID, uuid4 import pytest from sanic import Sanic, response -from sanic.exceptions import BadURL +from sanic.exceptions import BadURL, SanicException from sanic.request import Request, uuid from sanic.server import HttpProtocol @@ -217,3 +217,17 @@ async def test_request_scope_is_not_none_when_running_in_asgi(app): assert request.scope is not None assert request.scope["method"].lower() == "get" assert request.scope["path"].lower() == "/" + + +def test_cannot_get_request_outside_of_cycle(): + with pytest.raises(SanicException, match="No current request"): + Request.get_current() + + +def test_get_current_request(app): + @app.get("/") + async def get(request): + return response.json({"same": request is Request.get_current()}) + + _, resp = app.test_client.get("/") + assert resp.json["same"]