sanic/sanic/server/protocols/base_protocol.py
2023-10-04 14:30:10 +03:00

147 lines
4.0 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from sanic.exceptions import RequestCancelled
if TYPE_CHECKING:
from sanic.app import Sanic
import asyncio
from asyncio.transports import Transport
from time import monotonic as current_time
from sanic.log import error_logger
from sanic.models.server_types import ConnInfo, Signal
class SanicProtocol(asyncio.Protocol):
    """
    Base asyncio protocol shared by Sanic's server protocols.

    Holds the transport and connection info for one client connection,
    provides backpressure-aware ``send`` / ``receive_more`` helpers, and
    implements the generic ``asyncio.Protocol`` callbacks.
    """

    __slots__ = (
        "app",
        # event loop, connection
        "loop",
        "transport",
        "connections",
        "conn_info",
        "signal",
        "_can_write",
        "_time",
        "_task",
        "_unix",
        "_data_received",
    )

    def __init__(
        self,
        *,
        loop,
        app: Sanic,
        signal=None,
        connections=None,
        unix=None,
        **kwargs,
    ):
        """
        Initialise per-connection state.

        Args:
            loop: Event loop this protocol runs on. It is also installed as
                the current event loop via ``asyncio.set_event_loop``.
            app: The Sanic application serving this connection.
            signal: Server ``Signal`` object; a new ``Signal()`` is created
                when omitted (or falsy).
            connections: Set of live protocol instances that this protocol
                adds itself to / removes itself from; a new empty set is used
                when ``None``.
            unix: Value forwarded to ``ConnInfo(..., unix=...)``.
            **kwargs: Accepted and ignored by this base class.
        """
        asyncio.set_event_loop(loop)
        self.loop = loop
        self.app: Sanic = app
        self.signal = signal or Signal()
        self.transport: Optional[Transport] = None
        self.connections = connections if connections is not None else set()
        self.conn_info: Optional[ConnInfo] = None
        # Writing is allowed until the transport asks us to pause
        # (pause_writing clears this event, resume_writing sets it again).
        self._can_write = asyncio.Event()
        self._can_write.set()
        self._unix = unix
        # time.monotonic() timestamp of the last send/receive activity.
        self._time: float = 0.0
        self._task: Optional[asyncio.Task] = None
        # Set by data_received to wake a coroutine blocked in receive_more.
        self._data_received = asyncio.Event()

    @property
    def ctx(self):
        """
        Context object of the current connection, or ``None`` while
        ``conn_info`` has not been populated by ``connection_made``.
        """
        conn_info = self.conn_info
        return None if conn_info is None else conn_info.ctx

    async def send(self, data):
        """
        Generic data write implementation with backpressure control.

        Waits until the transport accepts more data, then writes ``data`` and
        records the activity time.

        Raises:
            RequestCancelled: if the transport is closing, or has already
                been detached by ``abort()`` (``self.transport is None``).
        """
        await self._can_write.wait()
        # Re-read the transport after the await: abort() may have detached it
        # while we were paused, and connection_lost() wakes paused writers.
        transport = self.transport
        if transport is None or transport.is_closing():
            raise RequestCancelled
        transport.write(data)
        self._time = current_time()

    async def receive_more(self):
        """
        Wait until more data is received into the Server protocol's buffer.

        Resumes reading on the transport and blocks until the next
        ``data_received`` callback. Requires an attached transport.
        """
        self.transport.resume_reading()
        # Clear before waiting so only data arriving after this call wakes us.
        self._data_received.clear()
        await self._data_received.wait()

    def close(self, timeout: Optional[float] = None):
        """
        Attempt to close the connection.

        Closes the transport, then aborts it immediately when ``timeout`` is
        ``None``, or schedules ``abort`` after ``timeout`` seconds otherwise.
        Does nothing if no transport is attached.
        """
        # Cause a call to connection_lost where further cleanup occurs
        if self.transport:
            self.transport.close()
            if timeout is None:
                self.abort()
            else:
                self.loop.call_later(timeout, self.abort)

    def abort(self):
        """
        Force close the connection.

        Aborts the transport and detaches it (``self.transport = None``).
        Does nothing if no transport is attached.
        """
        # Cause a call to connection_lost where further cleanup occurs
        if self.transport:
            self.transport.abort()
            self.transport = None

    # asyncio.Protocol API Callbacks #
    # ------------------------------ #
    def connection_made(self, transport):
        """
        Generic connection-made, with no connection_task, and no recv_buffer.
        Override this for protocol-specific connection implementations.

        Sets the transport's write-buffer watermarks (asyncio calls
        ``pause_writing`` above 64 KiB buffered and ``resume_writing`` once
        drained below 16 KiB), registers this protocol in ``connections`` and
        builds ``conn_info``. Failures are logged, not raised.
        """
        try:
            transport.set_write_buffer_limits(low=16384, high=65536)
            self.connections.add(self)
            self.transport = transport
            self.conn_info = ConnInfo(self.transport, unix=self._unix)
        except Exception:
            error_logger.exception("protocol.connect_made")

    def connection_lost(self, exc):
        """
        asyncio callback: the connection was closed or lost.

        Deregisters this protocol, wakes any writer blocked on backpressure,
        marks the connection info as lost and cancels the handler task.
        Errors are logged rather than propagated to the event loop.
        """
        try:
            self.connections.discard(self)
            # Unblock paused send() calls so they can notice the closed
            # transport instead of waiting forever.
            self.resume_writing()
            # conn_info stays None if connection_made failed; guard it so the
            # task cancellation below still happens.
            if self.conn_info is not None:
                self.conn_info.lost = True
            if self._task:
                self._task.cancel()
        except BaseException:
            error_logger.exception("protocol.connection_lost")

    def pause_writing(self):
        """asyncio callback: transport buffer is above the high watermark."""
        self._can_write.clear()

    def resume_writing(self):
        """asyncio callback: transport buffer drained below the low watermark."""
        self._can_write.set()

    def data_received(self, data: bytes):
        """
        asyncio callback: bytes arrived on the transport.

        Records the activity time, closes the connection on an empty chunk,
        and wakes a coroutine waiting in ``receive_more``. This base
        implementation does not store ``data`` itself. Errors are logged.
        """
        try:
            self._time = current_time()
            if not data:
                return self.close()
            if self._data_received is not None:
                self._data_received.set()
        except BaseException:
            error_logger.exception("protocol.data_received")