diff --git a/sanic/server.py b/sanic/server.py index 18eb79d2..eea0f030 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -102,9 +102,126 @@ class ConnInfo: self.client_port = addr[1] -class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): +class SanicProtocol(asyncio.Protocol): + __slots__ = ( + "app", + # event loop, connection + "loop", + "transport", + "connections", + "conn_info", + "signal", + "_can_write", + "_time", + "_task", + "_unix", + "_data_received", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + unix=None, + **kwargs, + ): + asyncio.set_event_loop(loop) + self.loop = loop + self.app: Sanic = app + self.signal = signal or Signal() + self.transport: Optional[Transport] = None + self.connections = connections if connections is not None else set() + self.conn_info: Optional[ConnInfo] = None + self._can_write = asyncio.Event() + self._can_write.set() + self._unix = unix + self._time = 0.0 # type: float + self._task = None # type: Optional[asyncio.Task] + self._data_received = asyncio.Event() + + @property + def ctx(self): + if self.conn_info is not None: + return self.conn_info.ctx + else: + return None + + async def send(self, data): + """ + Generic data write implementation with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + self.transport.write(data) + self._time = current_time() + + async def receive_more(self): + """ + Wait until more data is received into the Server protocol's buffer + """ + self.transport.resume_reading() + self._data_received.clear() + await self._data_received.wait() + + def close(self): + """ + Force close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.close() + self.transport = None + + # asyncio.Protocol API Callbacks # + # ------------------------------ # + def connection_made(self, transport): + """ + Generic connection-made, with no connection_task, and no recv_buffer. + Override this for protocol-specific connection implementations. + """ + try: + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def connection_lost(self, exc): + try: + self.connections.discard(self) + self.resume_writing() + if self._task: + self._task.cancel() + except BaseException: + error_logger.exception("protocol.connection_lost") + + def pause_writing(self): + self._can_write.clear() + + def resume_writing(self): + self._can_write.set() + + def data_received(self, data: bytes): + try: + self._time = current_time() + if not data: + return self.close() + + if self._data_received: + self._data_received.set() + except BaseException: + error_logger.exception("protocol.data_received") + + +class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): """ - This class provides a basic HTTP implementation of the sanic framework. + This class provides implements the HTTP 1.1 protocol on top of our + Sanic Server transport """ __touchup__ = ( @@ -112,15 +229,6 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): "connection_task", ) __slots__ = ( - # app - "app", - # event loop, connection - "loop", - "transport", - "connections", - "signal", - "conn_info", - "ctx", # request params "request", # request config @@ -137,14 +245,9 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): "state", "url", "_handler_task", - "_can_write", - "_data_received", - "_time", - "_task", "_http", "_exception", "recv_buffer", - "_unix", ) def __init__( @@ -158,16 +261,16 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): unix=None, **kwargs, ): - asyncio.set_event_loop(loop) - self.loop = loop - self.app: Sanic = app + super().__init__( + loop=loop, + app=app, + signal=signal, + connections=connections, + unix=unix, + ) self.url = None - self.transport: Optional[Transport] = None - self.conn_info: Optional[ConnInfo] = None self.request: Optional[Request] = None - self.signal = signal or Signal() self.access_log = self.app.config.ACCESS_LOG - self.connections = connections if connections is not None else set() self.request_handler = self.app.handle_request self.error_handler = self.app.error_handler self.request_timeout = self.app.config.REQUEST_TIMEOUT @@ -178,11 +281,7 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): self.state = state if state else {} if "requests_count" not in self.state: self.state["requests_count"] = 0 - self._data_received = asyncio.Event() - self._can_write = asyncio.Event() - self._can_write.set() self._exception = None - self._unix = unix def _setup_connection(self): self._http = Http(self) @@ -229,14 +328,6 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): ) ... - async def receive_more(self): - """ - Wait until more data is received into the Server protocol's buffer - """ - self.transport.resume_reading() - self._data_received.clear() - await self._data_received.wait() - def check_timeouts(self): """ Runs itself periodically to enforce any expired timeouts. @@ -277,7 +368,7 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): async def send(self, data): # no cov """ - Writes data with backpressure control. + Writes HTTP data with backpressure control. """ await self._can_write.wait() if self.transport.is_closing(): @@ -301,20 +392,14 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): return True return False - def close(self): - """ - Force close the connection. - """ - # Cause a call to connection_lost where further cleanup occurs - if self.transport: - self.transport.close() - self.transport = None - # -------------------------------------------- # # Only asyncio.Protocol callbacks below this # -------------------------------------------- # def connection_made(self, transport): + """ + HTTP-protocol-specific new connection handler + """ try: # TODO: Benchmark to find suitable write buffer limits transport.set_write_buffer_limits(low=16384, high=65536) @@ -326,22 +411,8 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): except Exception: error_logger.exception("protocol.connect_made") - def connection_lost(self, exc): - try: - self.connections.discard(self) - self.resume_writing() - if self._task: - self._task.cancel() - except Exception: - error_logger.exception("protocol.connection_lost") - - def pause_writing(self): - self._can_write.clear() - - def resume_writing(self): - self._can_write.set() - def data_received(self, data: bytes): + try: self._time = current_time() if not data: @@ -349,7 +420,7 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): self.recv_buffer += data if ( - len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE + len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE and self.transport ): self.transport.pause_reading()