diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 457f1cd0..2321e949 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -1,7 +1,8 @@ -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Optional, Sequence, cast from websockets.connection import CLOSED, CLOSING, OPEN from websockets.server import ServerConnection +from websockets.typing import Subprotocol from sanic.exceptions import ServerError from sanic.log import error_logger @@ -15,13 +16,6 @@ if TYPE_CHECKING: class WebSocketProtocol(HttpProtocol): - - websocket: Optional[WebsocketImplProtocol] - websocket_timeout: float - websocket_max_size = Optional[int] - websocket_ping_interval = Optional[float] - websocket_ping_timeout = Optional[float] - def __init__( self, *args, @@ -35,7 +29,7 @@ class WebSocketProtocol(HttpProtocol): **kwargs, ): super().__init__(*args, **kwargs) - self.websocket = None + self.websocket: Optional[WebsocketImplProtocol] = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size if websocket_max_queue is not None and websocket_max_queue > 0: @@ -109,14 +103,22 @@ class WebSocketProtocol(HttpProtocol): return super().close_if_idle() async def websocket_handshake( - self, request, subprotocols=Optional[Sequence[str]] + self, request, subprotocols: Optional[Sequence[str]] = None ): # let the websockets package do the handshake with the client try: if subprotocols is not None: # subprotocols can be a set or frozenset, # but ServerConnection needs a list - subprotocols = list(subprotocols) + subprotocols = cast( + Optional[Sequence[Subprotocol]], + list( + [ + Subprotocol(subprotocol) + for subprotocol in subprotocols + ] + ), + ) ws_conn = ServerConnection( max_size=self.websocket_max_size, subprotocols=subprotocols, @@ -131,21 +133,18 @@ class WebSocketProtocol(HttpProtocol): ) raise ServerError(msg, status_code=500) if 100 <= resp.status_code <= 299: - rbody = "".join( - [ - "HTTP/1.1 ", - str(resp.status_code), - " ", - resp.reason_phrase, - "\r\n", - ] - ) - rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + first_line = ( + f"HTTP/1.1 {resp.status_code} {resp.reason_phrase}\r\n" + ).encode() + rbody = bytearray(first_line) + rbody += ( + "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + ).encode() + rbody += b"\r\n" if resp.body is not None: - rbody += f"\r\n{resp.body}\r\n\r\n" - else: - rbody += "\r\n" - await super().send(rbody.encode()) + rbody += resp.body + rbody += b"\r\n\r\n" + await super().send(rbody) else: raise ServerError(resp.body, resp.status_code) self.websocket = WebsocketImplProtocol(