Replacing assignation by typing for websocket_handshake (#2273)

* Replacing assignation by typing for `websocket_handshake`

Related to #2272

* Fix some type hinting issues

* Cleanup websocket handchake response concat

* Optimize concat encoding

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Cyril Nicodème 2021-10-27 09:00:04 +02:00 committed by GitHub
parent 645310cff6
commit 71cc30e5cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Optional, Sequence from typing import TYPE_CHECKING, Optional, Sequence, cast
from websockets.connection import CLOSED, CLOSING, OPEN from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection from websockets.server import ServerConnection
from websockets.typing import Subprotocol
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.log import error_logger from sanic.log import error_logger
@ -15,13 +16,6 @@ if TYPE_CHECKING:
class WebSocketProtocol(HttpProtocol): class WebSocketProtocol(HttpProtocol):
websocket: Optional[WebsocketImplProtocol]
websocket_timeout: float
websocket_max_size = Optional[int]
websocket_ping_interval = Optional[float]
websocket_ping_timeout = Optional[float]
def __init__( def __init__(
self, self,
*args, *args,
@ -35,7 +29,7 @@ class WebSocketProtocol(HttpProtocol):
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.websocket = None self.websocket: Optional[WebsocketImplProtocol] = None
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
self.websocket_max_size = websocket_max_size self.websocket_max_size = websocket_max_size
if websocket_max_queue is not None and websocket_max_queue > 0: if websocket_max_queue is not None and websocket_max_queue > 0:
@ -109,14 +103,22 @@ class WebSocketProtocol(HttpProtocol):
return super().close_if_idle() return super().close_if_idle()
async def websocket_handshake( async def websocket_handshake(
self, request, subprotocols=Optional[Sequence[str]] self, request, subprotocols: Optional[Sequence[str]] = None
): ):
# let the websockets package do the handshake with the client # let the websockets package do the handshake with the client
try: try:
if subprotocols is not None: if subprotocols is not None:
# subprotocols can be a set or frozenset, # subprotocols can be a set or frozenset,
# but ServerConnection needs a list # but ServerConnection needs a list
subprotocols = list(subprotocols) subprotocols = cast(
Optional[Sequence[Subprotocol]],
list(
[
Subprotocol(subprotocol)
for subprotocol in subprotocols
]
),
)
ws_conn = ServerConnection( ws_conn = ServerConnection(
max_size=self.websocket_max_size, max_size=self.websocket_max_size,
subprotocols=subprotocols, subprotocols=subprotocols,
@ -131,21 +133,18 @@ class WebSocketProtocol(HttpProtocol):
) )
raise ServerError(msg, status_code=500) raise ServerError(msg, status_code=500)
if 100 <= resp.status_code <= 299: if 100 <= resp.status_code <= 299:
rbody = "".join( first_line = (
[ f"HTTP/1.1 {resp.status_code} {resp.reason_phrase}\r\n"
"HTTP/1.1 ", ).encode()
str(resp.status_code), rbody = bytearray(first_line)
" ", rbody += (
resp.reason_phrase, "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
"\r\n", ).encode()
] rbody += b"\r\n"
)
rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
if resp.body is not None: if resp.body is not None:
rbody += f"\r\n{resp.body}\r\n\r\n" rbody += resp.body
else: rbody += b"\r\n\r\n"
rbody += "\r\n" await super().send(rbody)
await super().send(rbody.encode())
else: else:
raise ServerError(resp.body, resp.status_code) raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol( self.websocket = WebsocketImplProtocol(