From e4be70bae8c090575a2ad6d16ccfba8ce0dec270 Mon Sep 17 00:00:00 2001 From: Michael Azimov Date: Wed, 29 Jun 2022 23:39:21 +0300 Subject: [PATCH] Add custom loads function (#2445) Co-authored-by: Zhiwei Co-authored-by: Adam Hopkins --- sanic/app.py | 3 +++ sanic/request.py | 6 ++++- tests/test_json_decoding.py | 52 +++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 tests/test_json_decoding.py diff --git a/sanic/app.py b/sanic/app.py index 62228726..47a600cb 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -170,6 +170,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): log_config: Optional[Dict[str, Any]] = None, configure_logging: bool = True, dumps: Optional[Callable[..., AnyStr]] = None, + loads: Optional[Callable[..., Any]] = None, ) -> None: super().__init__(name=name) @@ -223,6 +224,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): if dumps: BaseHTTPResponse._dumps = dumps # type: ignore + if loads: + Request._loads = loads # type: ignore @property def loop(self): diff --git a/sanic/request.py b/sanic/request.py index 9c6336fe..24a85765 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -87,6 +87,7 @@ class Request: """ _current: ContextVar[Request] = ContextVar("request") + _loads = json_loads __slots__ = ( "__weakref__", @@ -462,8 +463,11 @@ class Request: return self.parsed_json - def load_json(self, loads=json_loads): + def load_json(self, loads=None): try: + if not loads: + loads = self.__class__._loads + self.parsed_json = loads(self.body) except Exception: if not self.body: diff --git a/tests/test_json_decoding.py b/tests/test_json_decoding.py new file mode 100644 index 00000000..57fdbc6b --- /dev/null +++ b/tests/test_json_decoding.py @@ -0,0 +1,52 @@ +from json import loads as sloads + +import pytest + + +try: + from ujson import loads as uloads + + NO_UJSON = False + DEFAULT_LOADS = uloads +except ModuleNotFoundError: + NO_UJSON = True + DEFAULT_LOADS = sloads + +from sanic import Request, Sanic, json + + +@pytest.fixture(autouse=True) +def default_back_to_ujson(): + yield + Request._loads = DEFAULT_LOADS + + +def test_change_decoder(): + Sanic("Test", loads=sloads) + assert Request._loads == sloads + + +def test_change_decoder_to_some_custom(): + def my_custom_decoder(some_str: str): + dict = sloads(some_str) + dict["some_key"] = "new_value" + return dict + + app = Sanic("Test", loads=my_custom_decoder) + assert Request._loads == my_custom_decoder + + req_body = {"some_key": "some_value"} + + @app.post("/test") + def handler(request): + new_json = request.json + return json(new_json) + + req, res = app.test_client.post("/test", json=req_body) + assert sloads(res.body) == {"some_key": "new_value"} + + +@pytest.mark.skipif(NO_UJSON is True, reason="ujson not installed") +def test_default_decoder(): + Sanic("Test") + assert Request._loads == uloads