Better request cancel handling (#2513)

This commit is contained in:
Adam Hopkins 2022-09-19 16:04:09 +03:00 committed by GitHub
parent 7f894c45b3
commit 389363ab71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 46 additions and 8 deletions

View File

@ -1,8 +1,13 @@
from asyncio import CancelledError
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from sanic.helpers import STATUS_CODES from sanic.helpers import STATUS_CODES
class RequestCancelled(CancelledError):
quiet = True
class SanicException(Exception): class SanicException(Exception):
message: str = "" message: str = ""

View File

@ -14,8 +14,8 @@ from sanic.exceptions import (
BadRequest, BadRequest,
ExpectationFailed, ExpectationFailed,
PayloadTooLarge, PayloadTooLarge,
RequestCancelled,
ServerError, ServerError,
ServiceUnavailable,
) )
from sanic.headers import format_http1_response from sanic.headers import format_http1_response
from sanic.helpers import has_message_body from sanic.helpers import has_message_body
@ -132,7 +132,7 @@ class Http(Stream, metaclass=TouchUpMeta):
if self.stage is Stage.RESPONSE: if self.stage is Stage.RESPONSE:
await self.response.send(end_stream=True) await self.response.send(end_stream=True)
except CancelledError: except CancelledError as exc:
# Write an appropriate response before exiting # Write an appropriate response before exiting
if not self.protocol.transport: if not self.protocol.transport:
logger.info( logger.info(
@ -140,7 +140,11 @@ class Http(Stream, metaclass=TouchUpMeta):
"stopped. Transport is closed." "stopped. Transport is closed."
) )
return return
e = self.exception or ServiceUnavailable("Cancelled") e = (
RequestCancelled()
if self.protocol.conn_info.lost
else (self.exception or exc)
)
self.exception = None self.exception = None
self.keep_alive = False self.keep_alive = False
await self.error_response(e) await self.error_response(e)

View File

@ -21,6 +21,7 @@ class ConnInfo:
"client", "client",
"client_ip", "client_ip",
"ctx", "ctx",
"lost",
"peername", "peername",
"server_port", "server_port",
"server", "server",
@ -33,6 +34,7 @@ class ConnInfo:
def __init__(self, transport: TransportProtocol, unix=None): def __init__(self, transport: TransportProtocol, unix=None):
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.lost = False
self.peername = None self.peername = None
self.server = self.client = "" self.server = self.client = ""
self.server_port = self.client_port = 0 self.server_port = self.client_port = 0

View File

@ -2,13 +2,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from sanic.exceptions import RequestCancelled
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.app import Sanic from sanic.app import Sanic
import asyncio import asyncio
from asyncio import CancelledError
from asyncio.transports import Transport from asyncio.transports import Transport
from time import monotonic as current_time from time import monotonic as current_time
@ -69,7 +70,7 @@ class SanicProtocol(asyncio.Protocol):
""" """
await self._can_write.wait() await self._can_write.wait()
if self.transport.is_closing(): if self.transport.is_closing():
raise CancelledError raise RequestCancelled
self.transport.write(data) self.transport.write(data)
self._time = current_time() self._time = current_time()
@ -120,6 +121,7 @@ class SanicProtocol(asyncio.Protocol):
try: try:
self.connections.discard(self) self.connections.discard(self)
self.resume_writing() self.resume_writing()
self.conn_info.lost = True
if self._task: if self._task:
self._task.cancel() self._task.cancel()
except BaseException: except BaseException:

View File

@ -15,7 +15,11 @@ import sys
from asyncio import CancelledError from asyncio import CancelledError
from time import monotonic as current_time from time import monotonic as current_time
from sanic.exceptions import RequestTimeout, ServiceUnavailable from sanic.exceptions import (
RequestCancelled,
RequestTimeout,
ServiceUnavailable,
)
from sanic.http import Http, Stage from sanic.http import Http, Stage
from sanic.log import Colors, error_logger, logger from sanic.log import Colors, error_logger, logger
from sanic.models.server_types import ConnInfo from sanic.models.server_types import ConnInfo
@ -225,7 +229,7 @@ class HttpProtocol(HttpProtocolMixin, SanicProtocol, metaclass=TouchUpMeta):
""" """
await self._can_write.wait() await self._can_write.wait()
if self.transport.is_closing(): if self.transport.is_closing():
raise CancelledError raise RequestCancelled
await self.app.dispatch( await self.app.dispatch(
"http.lifecycle.send", "http.lifecycle.send",
inline=True, inline=True,

View File

@ -0,0 +1,21 @@
import asyncio
from asyncio import CancelledError
import pytest
from sanic import Request, Sanic, json
def test_can_raise_in_handler(app: Sanic):
@app.get("/")
async def handler(request: Request):
raise CancelledError("STOP!!")
@app.exception(CancelledError)
async def handle_cancel(request: Request, exc: CancelledError):
return json({"message": exc.args[0]}, status=418)
_, response = app.test_client.get("/")
assert response.status == 418
assert response.json["message"] == "STOP!!"

View File

@ -166,7 +166,7 @@ def test_middleware_response_raise_cancelled_error(app, caplog):
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
reqrequest, response = app.test_client.get("/") reqrequest, response = app.test_client.get("/")
assert response.status == 503 assert response.status == 500
assert ( assert (
"sanic.error", "sanic.error",
logging.ERROR, logging.ERROR,