diff --git a/sanic/asgi.py b/sanic/asgi.py
index 8213fdca..a49ae633 100644
--- a/sanic/asgi.py
+++ b/sanic/asgi.py
@@ -20,7 +20,7 @@ import sanic.app  # noqa
 from sanic.compat import Header
 from sanic.exceptions import InvalidUsage, ServerError
 from sanic.log import logger
-from sanic.request import Request, StreamBuffer
+from sanic.request import Request
 from sanic.response import HTTPResponse, StreamingHTTPResponse
 from sanic.websocket import WebSocketConnection
@@ -250,11 +250,11 @@ class ASGIApp:
             instance.transport,
             sanic_app,
         )
-        instance.request.stream = StreamBuffer(protocol=instance)
+        instance.request.stream = instance

         return instance

-    async def stream_body(self) -> None:
+    async def read(self) -> None:
         """
         Read and stream the body in chunks from an incoming ASGI message.
         """
@@ -263,6 +263,13 @@ class ASGIApp:
             return None

         return message.get("body", b"")

+    async def __aiter__(self):
+        while True:
+            data = await self.read()
+            if not data:
+                return
+            yield data
+
     def respond(self, response):
         headers: List[Tuple[bytes, bytes]] = []
         cookies: Dict[str, str] = {}
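
With this change the ASGIApp instance itself is assigned to request.stream and
provides the read()/__aiter__ pair, so a request body can be consumed with
"async for". Below is a minimal, self-contained sketch of that contract; the
FakeReceive and BodyStream names are made up for illustration (in Sanic the
messages come from the transport wrapper around the ASGI receive callable), but
the read() and __aiter__ bodies mirror the hunks above.

    import asyncio

    class FakeReceive:
        """Hypothetical stand-in for the ASGI receive channel."""

        def __init__(self, chunks):
            self._messages = [{"body": c, "more_body": True} for c in chunks]
            self._messages.append({"body": b"", "more_body": False})

        async def receive(self):
            return self._messages.pop(0)

    class BodyStream:
        """Mirrors the read()/__aiter__ contract added to ASGIApp."""

        def __init__(self, transport):
            self.transport = transport

        async def read(self):
            message = await self.transport.receive()
            if not message.get("more_body", False):
                return None
            return message.get("body", b"")

        async def __aiter__(self):
            while True:
                data = await self.read()
                if not data:
                    return
                yield data

    async def main():
        stream = BodyStream(FakeReceive([b"hello ", b"world"]))
        body = b"".join([chunk async for chunk in stream])
        print(body)  # b'hello world'

    asyncio.run(main())
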
diff --git a/sanic/http.py b/sanic/http.py
index 8c0ecae9..04e3320e 100644
--- a/sanic/http.py
+++ b/sanic/http.py
@@ -17,7 +17,7 @@ from sanic.response import HTTPResponse
 from sanic.compat import Header


-class Lifespan(Enum):
+class Stage(Enum):
     IDLE = 0  # Waiting for request
     REQUEST = 1  # Request headers being received
     HANDLER = 3  # Headers done, handler running
@@ -29,20 +29,21 @@ HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"

 class Http:
     def __init__(self, protocol):
-        self._send = protocol.push_data
+        self._send = protocol.send
         self._receive_more = protocol.receive_more
+        self.recv_buffer = protocol.recv_buffer
         self.protocol = protocol
-        self.recv_buffer = bytearray()
         self.expecting_continue = False
         # Note: connections are initially in request mode and do not obey
         # keep-alive timeout like with some other servers.
-        self.lifespan = Lifespan.REQUEST
+        self.stage = Stage.REQUEST
+        self.keep_alive = True
+        self.head_only = None

     async def http1(self):
         """HTTP 1.1 connection handler"""
         buf = self.recv_buffer
-        self.keep_alive = True
-        url = None
+        self.url = None
         while self.keep_alive:
             # Read request header
             pos = 0
@@ -53,17 +54,17 @@ class Http:
                         break
                     pos = max(0, len(buf) - 3)
                 await self._receive_more()
-                if self.lifespan is Lifespan.IDLE:
-                    self.lifespan = Lifespan.REQUEST
+                if self.stage is Stage.IDLE:
+                    self.stage = Stage.REQUEST
             else:
-                self.lifespan = Lifespan.HANDLER
+                self.stage = Stage.HANDLER
                 raise PayloadTooLarge("Payload Too Large")

             self.protocol._total_request_size = pos + 4

             try:
                 reqline, *headers = buf[:pos].decode().split("\r\n")
-                method, url, protocol = reqline.split(" ")
+                method, self.url, protocol = reqline.split(" ")
                 if protocol not in ("HTTP/1.0", "HTTP/1.1"):
                     raise Exception
                 self.head_only = method.upper() == "HEAD"
@@ -72,12 +73,12 @@
                     for name, value in (h.split(":", 1) for h in headers)
                 )
             except:
-                self.lifespan = Lifespan.HANDLER
+                self.stage = Stage.HANDLER
                 raise InvalidUsage("Bad Request")

             # Prepare a request object from the header received
             request = self.protocol.request_class(
-                url_bytes=url.encode(),
+                url_bytes=self.url.encode(),
                 headers=headers,
                 version=protocol[-3:],
                 method=method,
@@ -86,8 +87,6 @@
             )
             request.stream = self
             self.protocol.state["requests_count"] += 1
-            self.protocol.url = url
-            self.protocol.request = request
             self.keep_alive = (
                 protocol == "HTTP/1.1"
                 or headers.get("connection", "").lower() == "keep-alive"
@@ -98,7 +97,7 @@
             )
             self.request_chunked = False
             self.request_bytes_left = 0
-            self.lifespan = Lifespan.HANDLER
+            self.stage = Stage.HANDLER
             if body:
                 expect = headers.get("expect")
                 if expect:
@@ -121,14 +120,14 @@
             except Exception:
                 logger.exception("Uncaught from app/handler")
                 await self.write_error(ServerError("Internal Server Error"))
-                if self.lifespan is Lifespan.IDLE:
+                if self.stage is Stage.IDLE:
                     continue

-            if self.lifespan is Lifespan.HANDLER:
+            if self.stage is Stage.HANDLER:
                 await self.respond(HTTPResponse(status=204)).send(end_stream=True)

             # Finish sending a response (if no error)
-            elif self.lifespan is Lifespan.RESPONSE:
+            elif self.stage is Stage.RESPONSE:
                 await self.send(end_stream=True)

             # Consume any remaining request body
@@ -139,10 +138,10 @@
                 while await self.read():
                     pass

-            self.lifespan = Lifespan.IDLE
+            self.stage = Stage.IDLE

     async def write_error(self, e):
-        if self.lifespan is Lifespan.HANDLER:
+        if self.stage is Stage.HANDLER:
             try:
                 response = HTTPResponse(f"{e}", e.status_code, content_type="text/plain")
                 await self.respond(response).send(end_stream=True)
@@ -210,8 +209,8 @@
         Nothing is sent until the first send() call on the returned object, and
         calling this function multiple times will just alter the response to be
         given."""
-        if self.lifespan is not Lifespan.HANDLER:
-            self.lifespan = Lifespan.FAILED
+        if self.stage is not Stage.HANDLER:
+            self.stage = Stage.FAILED
             raise RuntimeError("Response already started")
         if not isinstance(response.status, int) or response.status < 200:
             raise RuntimeError(f"Invalid response status {response.status!r}")
@@ -240,7 +239,7 @@
         size = len(data) if data is not None else 0

         # Headers not yet sent?
-        if self.lifespan is Lifespan.HANDLER:
+        if self.stage is Stage.HANDLER:
             if self.response.body:
                 data = self.response.body + data if data else self.response.body
                 size = len(data)
@@ -297,14 +296,14 @@
             if status != 417:
                 ret = HTTP_CONTINUE + ret
             # Send response
-            self.lifespan = Lifespan.IDLE if end_stream else Lifespan.RESPONSE
+            self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
             return ret

         # HEAD request: don't send body
         if self.head_only:
             return None

-        if self.lifespan is not Lifespan.RESPONSE:
+        if self.stage is not Stage.RESPONSE:
             if size:
                 raise RuntimeError("Cannot send data to a closed stream")
             return
@@ -313,7 +312,7 @@
         if self.response_bytes_left is True:
             if end_stream:
                 self.response_bytes_left = None
-                self.lifespan = Lifespan.IDLE
+                self.stage = Stage.IDLE
                 if size:
                     return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
                 return b"0\r\n\r\n"
@@ -325,5 +324,5 @@
         if self.response_bytes_left <= 0:
             if self.response_bytes_left < 0:
                 raise ServerError("Response was bigger than content-length")
-            self.lifespan = Lifespan.IDLE
+            self.stage = Stage.IDLE
         return data if size else None
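
The chunked branch above builds its byte strings with standard HTTP/1.1 chunked
framing: every chunk is prefixed with its size in hexadecimal and the stream is
terminated by a zero-length chunk. A small self-contained sketch of that
framing; frame_chunk is a made-up helper name, not part of the patch, and the
intermediate-chunk case is not shown in the hunk above but follows the same
RFC 7230 rules.

    def frame_chunk(data: bytes, end_stream: bool) -> bytes:
        """Frame one piece of body data using HTTP/1.1 chunked encoding."""
        size = len(data)
        if end_stream:
            # Final data chunk followed by the zero-length terminating chunk
            return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) if size else b"0\r\n\r\n"
        return b"%x\r\n%b\r\n" % (size, data)

    assert frame_chunk(b"hello", end_stream=False) == b"5\r\nhello\r\n"
    assert frame_chunk(b"hello", end_stream=True) == b"5\r\nhello\r\n0\r\n\r\n"
    assert frame_chunk(b"", end_stream=True) == b"0\r\n\r\n"
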
diff --git a/sanic/request.py b/sanic/request.py
index 097808eb..7b1b4939 100644
--- a/sanic/request.py
+++ b/sanic/request.py
@@ -47,19 +47,6 @@ class RequestParameters(dict):
         return super().get(name, default)


-class StreamBuffer:
-    def __init__(self, protocol):
-        self.read = protocol.stream_body
-        self.respond = protocol.respond
-
-    async def __aiter__(self):
-        while True:
-            data = await self.read()
-            if not data:
-                return
-            yield data
-
-
 class Request:
     """Properties of an HTTP request such as URL, headers, etc."""

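
StreamBuffer is removed, but the handler-facing surface stays the same:
request.stream is now the protocol object itself (Http or ASGIApp), and both
expose read() and __aiter__. A hedged usage sketch, assuming the route-level
stream=True flag from released Sanic versions; the app name and route are made
up for illustration.

    from sanic import Sanic
    from sanic.response import text

    app = Sanic("upload_example")

    @app.post("/upload", stream=True)
    async def upload(request):
        total = 0
        # request.stream is the Http/ASGIApp instance; iteration stops once
        # read() returns an empty chunk or None.
        async for chunk in request.stream:
            total += len(chunk)
        return text(f"received {total} bytes")

    if __name__ == "__main__":
        app.run(host="127.0.0.1", port=8000)
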
+ """ + try: + self._http = Http(self) + self._time = current_time() + self.check_timeouts() + try: + await self._http.http1() + except asyncio.CancelledError: + await self._http.write_error( + self._exception + or ServiceUnavailable("Request handler cancelled") + ) + except SanicException as e: + await self._http.write_error(e) + except BaseException as e: + logger.exception( + f"Uncaught exception while handling URL {self._http.url}" + ) + except asyncio.CancelledError: + pass + except: + logger.exception("protocol.connection_task uncaught") + finally: + self._http = None self._task = None - - def pause_writing(self): - self._can_write.clear() - - def resume_writing(self): - self._can_write.set() + try: + self.close() + except: + logger.exception("Closing failed") async def receive_more(self): """Wait until more data is received into self._buffer.""" @@ -167,51 +176,18 @@ class HttpProtocol(asyncio.Protocol): self._data_received.clear() await self._data_received.wait() - # -------------------------------------------- # - # Parsing - # -------------------------------------------- # - - def data_received(self, data): - self._time = current_time() - if not data: - return self.close() - self._http.recv_buffer += data - if len(self._http.recv_buffer) > self.request_max_size: - self.transport.pause_reading() - - if self._data_received: - self._data_received.set() - - async def connection_task(self): - try: - await self._http.http1() - except asyncio.CancelledError: - await self._http.write_error( - self._exception - or ServiceUnavailable("Request handler cancelled") - ) - except SanicException as e: - await self._http.write_error(e) - except BaseException as e: - logger.exception(f"Uncaught exception while handling URL {url}") - finally: - try: - self.close() - except: - logger.exception("Closing failed") - def check_timeouts(self): """Runs itself once a second to enforce any expired timeouts.""" if not self._task: return duration = current_time() - self._time - lifespan = self._http.lifespan - if lifespan == Lifespan.IDLE and duration > self.keep_alive_timeout: + stage = self._http.stage + if stage is Stage.IDLE and duration > self.keep_alive_timeout: logger.debug("KeepAlive Timeout. Closing connection.") - elif lifespan == Lifespan.REQUEST and duration > self.request_timeout: + elif stage is Stage.REQUEST and duration > self.request_timeout: self._exception = RequestTimeout("Request Timeout") elif ( - lifespan.value > Lifespan.REQUEST.value + stage in (Stage.REQUEST, Stage.FAILED) and duration > self.response_timeout ): self._exception = ServiceUnavailable("Response Timeout") @@ -220,20 +196,18 @@ class HttpProtocol(asyncio.Protocol): return self._task.cancel() - async def drain(self): + async def send(self, data): + """Writes data with backpressure control.""" await self._can_write.wait() - - async def push_data(self, data): - self._time = current_time() - await self.drain() self.transport.write(data) + self._time = current_time() def close_if_idle(self): """Close the connection if a request is not being sent or received :return: boolean - True if closed, false if staying open """ - if self._http.lifespan == Lifespan.IDLE: + if self._http.stage is Stage.IDLE: self.close() return True return False @@ -242,16 +216,56 @@ class HttpProtocol(asyncio.Protocol): """ Force close the connection. 
""" - if self.transport is not None: - try: - if self._task: - self._task.cancel() - self._task = None - self.transport.close() - self.resume_writing() - finally: - self.transport = None + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.close() + self.transport = None + # -------------------------------------------- # + # Only asyncio.Protocol callbacks below this + # -------------------------------------------- # + + def connection_made(self, transport): + try: + # TODO: Benchmark to find suitable write buffer limits + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self._task = self.loop.create_task(self.connection_task()) + self.recv_buffer = bytearray() + except: + logger.exception("protocol.connect_made") + + def connection_lost(self, exc): + try: + self.connections.discard(self) + self.resume_writing() + if self._task: + self._task.cancel() + except: + logger.exception("protocol.connection_lost") + + def pause_writing(self): + self._can_write.clear() + + def resume_writing(self): + self._can_write.set() + + def data_received(self, data): + try: + self._time = current_time() + if not data: + return self.close() + self.recv_buffer += data + + # Buffer up to request max size + if len(self.recv_buffer) > self.request_max_size: + self.transport.pause_reading() + + if self._data_received: + self._data_received.set() + except: + logger.exception("protocol.data_received") def trigger_events(events, loop): """Trigger event callbacks (functions or async)