Split HttpProtocol parts into base SanicProtocol and HTTPProtocol subclass (#2229)

* Split HttpProtocol parts into base SanicProtocol and HTTPProtocol subclass.

* lint fixes

* re-black server.py
This commit is contained in:
Ashley Sommer 2021-08-31 17:49:58 +10:00 committed by GitHub
parent 2e5c288fea
commit 9d0b54c90d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -102,9 +102,126 @@ class ConnInfo:
self.client_port = addr[1]
class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
class SanicProtocol(asyncio.Protocol):
__slots__ = (
"app",
# event loop, connection
"loop",
"transport",
"connections",
"conn_info",
"signal",
"_can_write",
"_time",
"_task",
"_unix",
"_data_received",
)
def __init__(
self,
*,
loop,
app: Sanic,
signal=None,
connections=None,
unix=None,
**kwargs,
):
asyncio.set_event_loop(loop)
self.loop = loop
self.app: Sanic = app
self.signal = signal or Signal()
self.transport: Optional[Transport] = None
self.connections = connections if connections is not None else set()
self.conn_info: Optional[ConnInfo] = None
self._can_write = asyncio.Event()
self._can_write.set()
self._unix = unix
self._time = 0.0 # type: float
self._task = None # type: Optional[asyncio.Task]
self._data_received = asyncio.Event()
@property
def ctx(self):
if self.conn_info is not None:
return self.conn_info.ctx
else:
return None
async def send(self, data):
"""
Generic data write implementation with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
self.transport.write(data)
self._time = current_time()
async def receive_more(self):
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def close(self):
"""
Force close the connection.
"""
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None
# asyncio.Protocol API Callbacks #
# ------------------------------ #
def connection_made(self, transport):
"""
Generic connection-made, with no connection_task, and no recv_buffer.
Override this for protocol-specific connection implementations.
"""
try:
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
error_logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except BaseException:
error_logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()
if self._data_received:
self._data_received.set()
except BaseException:
error_logger.exception("protocol.data_received")
class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
"""
This class provides a basic HTTP implementation of the sanic framework.
This class provides implements the HTTP 1.1 protocol on top of our
Sanic Server transport
"""
__touchup__ = (
@ -112,15 +229,6 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
"connection_task",
)
__slots__ = (
# app
"app",
# event loop, connection
"loop",
"transport",
"connections",
"signal",
"conn_info",
"ctx",
# request params
"request",
# request config
@ -137,14 +245,9 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
"state",
"url",
"_handler_task",
"_can_write",
"_data_received",
"_time",
"_task",
"_http",
"_exception",
"recv_buffer",
"_unix",
)
def __init__(
@ -158,16 +261,16 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
unix=None,
**kwargs,
):
asyncio.set_event_loop(loop)
self.loop = loop
self.app: Sanic = app
super().__init__(
loop=loop,
app=app,
signal=signal,
connections=connections,
unix=unix,
)
self.url = None
self.transport: Optional[Transport] = None
self.conn_info: Optional[ConnInfo] = None
self.request: Optional[Request] = None
self.signal = signal or Signal()
self.access_log = self.app.config.ACCESS_LOG
self.connections = connections if connections is not None else set()
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
@ -178,11 +281,7 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._data_received = asyncio.Event()
self._can_write = asyncio.Event()
self._can_write.set()
self._exception = None
self._unix = unix
def _setup_connection(self):
self._http = Http(self)
@ -229,14 +328,6 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
)
...
async def receive_more(self):
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def check_timeouts(self):
"""
Runs itself periodically to enforce any expired timeouts.
@ -277,7 +368,7 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
async def send(self, data): # no cov
"""
Writes data with backpressure control.
Writes HTTP data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
@ -301,20 +392,14 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
return True
return False
def close(self):
"""
Force close the connection.
"""
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
"""
HTTP-protocol-specific new connection handler
"""
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
@ -326,22 +411,8 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
except Exception:
error_logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except Exception:
error_logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
@ -349,7 +420,7 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
self.recv_buffer += data
if (
len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE
len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE
and self.transport
):
self.transport.pause_reading()