Add custom loads function (#2445)

Co-authored-by: Zhiwei <chihwei.public@outlook.com>
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Michael Azimov 2022-06-29 23:39:21 +03:00 committed by GitHub
parent 13d5a44278
commit e4be70bae8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 1 deletions

View File

@ -170,6 +170,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
log_config: Optional[Dict[str, Any]] = None, log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True, configure_logging: bool = True,
dumps: Optional[Callable[..., AnyStr]] = None, dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
) -> None: ) -> None:
super().__init__(name=name) super().__init__(name=name)
@ -223,6 +224,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
if dumps: if dumps:
BaseHTTPResponse._dumps = dumps # type: ignore BaseHTTPResponse._dumps = dumps # type: ignore
if loads:
Request._loads = loads # type: ignore
@property @property
def loop(self): def loop(self):

View File

@ -87,6 +87,7 @@ class Request:
""" """
_current: ContextVar[Request] = ContextVar("request") _current: ContextVar[Request] = ContextVar("request")
_loads = json_loads
__slots__ = ( __slots__ = (
"__weakref__", "__weakref__",
@ -462,8 +463,11 @@ class Request:
return self.parsed_json return self.parsed_json
def load_json(self, loads=json_loads): def load_json(self, loads=None):
try: try:
if not loads:
loads = self.__class__._loads
self.parsed_json = loads(self.body) self.parsed_json = loads(self.body)
except Exception: except Exception:
if not self.body: if not self.body:

View File

@ -0,0 +1,52 @@
from json import loads as sloads
import pytest
try:
from ujson import loads as uloads
NO_UJSON = False
DEFAULT_LOADS = uloads
except ModuleNotFoundError:
NO_UJSON = True
DEFAULT_LOADS = sloads
from sanic import Request, Sanic, json
@pytest.fixture(autouse=True)
def default_back_to_ujson():
yield
Request._loads = DEFAULT_LOADS
def test_change_decoder():
Sanic("Test", loads=sloads)
assert Request._loads == sloads
def test_change_decoder_to_some_custom():
def my_custom_decoder(some_str: str):
dict = sloads(some_str)
dict["some_key"] = "new_value"
return dict
app = Sanic("Test", loads=my_custom_decoder)
assert Request._loads == my_custom_decoder
req_body = {"some_key": "some_value"}
@app.post("/test")
def handler(request):
new_json = request.json
return json(new_json)
req, res = app.test_client.post("/test", json=req_body)
assert sloads(res.body) == {"some_key": "new_value"}
@pytest.mark.skipif(NO_UJSON is True, reason="ujson not installed")
def test_default_decoder():
Sanic("Test")
assert Request._loads == uloads