diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 19ccc69b..7b243ddb 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -92,8 +92,10 @@ class BaseRenderer: self.full if self.debug and not getattr(self.exception, "quiet", False) else self.minimal - ) - return output() + )() + output.status = self.status + output.headers.update(self.headers) + return output def minimal(self) -> HTTPResponse: # noqa """ @@ -125,7 +127,7 @@ class HTMLRenderer(BaseRenderer): request=self.request, exc=self.exception, ) - return html(page.render(), status=self.status, headers=self.headers) + return html(page.render()) def minimal(self) -> HTTPResponse: return self.full() @@ -146,8 +148,7 @@ class TextRenderer(BaseRenderer): text=self.text, bar=("=" * len(self.title)), body=self._generate_body(full=True), - ), - status=self.status, + ) ) def minimal(self) -> HTTPResponse: @@ -157,9 +158,7 @@ class TextRenderer(BaseRenderer): text=self.text, bar=("=" * len(self.title)), body=self._generate_body(full=False), - ), - status=self.status, - headers=self.headers, + ) ) @property @@ -218,11 +217,11 @@ class JSONRenderer(BaseRenderer): def full(self) -> HTTPResponse: output = self._generate_output(full=True) - return json(output, status=self.status, dumps=self.dumps) + return json(output, dumps=self.dumps) def minimal(self) -> HTTPResponse: output = self._generate_output(full=False) - return json(output, status=self.status, dumps=self.dumps) + return json(output, dumps=self.dumps) def _generate_output(self, *, full): output = { diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index d2c1fc7b..a2df90b3 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -527,3 +527,26 @@ def test_guess_mime_logging( ] assert logmsg == expected + + +@pytest.mark.parametrize( + "format,expected", + ( + ("html", "text/html; charset=utf-8"), + ("text", "text/plain; charset=utf-8"), + ("json", "application/json"), + ), +) +def test_exception_header_on_renderers(app: Sanic, format, expected): + app.config.FALLBACK_ERROR_FORMAT = format + + @app.get("/test") + def test(request): + raise SanicException( + "test", status_code=400, headers={"exception": "test"} + ) + + _, response = app.test_client.get("/test") + assert response.status == 400 + assert response.headers.get("exception") == "test" + assert response.content_type == expected