Replacing assignation by typing for websocket_handshake (#2273)

* Replacing assignation by typing for `websocket_handshake`

Related to #2272

* Fix some type hinting issues

* Cleanup websocket handchake response concat

* Optimize concat encoding

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Cyril Nicodème 2021-10-27 09:00:04 +02:00 committed by GitHub
parent 645310cff6
commit 71cc30e5cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Optional, Sequence
from typing import TYPE_CHECKING, Optional, Sequence, cast
from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection
from websockets.typing import Subprotocol
from sanic.exceptions import ServerError
from sanic.log import error_logger
@ -15,13 +16,6 @@ if TYPE_CHECKING:
class WebSocketProtocol(HttpProtocol):
websocket: Optional[WebsocketImplProtocol]
websocket_timeout: float
websocket_max_size = Optional[int]
websocket_ping_interval = Optional[float]
websocket_ping_timeout = Optional[float]
def __init__(
self,
*args,
@ -35,7 +29,7 @@ class WebSocketProtocol(HttpProtocol):
**kwargs,
):
super().__init__(*args, **kwargs)
self.websocket = None
self.websocket: Optional[WebsocketImplProtocol] = None
self.websocket_timeout = websocket_timeout
self.websocket_max_size = websocket_max_size
if websocket_max_queue is not None and websocket_max_queue > 0:
@ -109,14 +103,22 @@ class WebSocketProtocol(HttpProtocol):
return super().close_if_idle()
async def websocket_handshake(
self, request, subprotocols=Optional[Sequence[str]]
self, request, subprotocols: Optional[Sequence[str]] = None
):
# let the websockets package do the handshake with the client
try:
if subprotocols is not None:
# subprotocols can be a set or frozenset,
# but ServerConnection needs a list
subprotocols = list(subprotocols)
subprotocols = cast(
Optional[Sequence[Subprotocol]],
list(
[
Subprotocol(subprotocol)
for subprotocol in subprotocols
]
),
)
ws_conn = ServerConnection(
max_size=self.websocket_max_size,
subprotocols=subprotocols,
@ -131,21 +133,18 @@ class WebSocketProtocol(HttpProtocol):
)
raise ServerError(msg, status_code=500)
if 100 <= resp.status_code <= 299:
rbody = "".join(
[
"HTTP/1.1 ",
str(resp.status_code),
" ",
resp.reason_phrase,
"\r\n",
]
)
rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
first_line = (
f"HTTP/1.1 {resp.status_code} {resp.reason_phrase}\r\n"
).encode()
rbody = bytearray(first_line)
rbody += (
"".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
).encode()
rbody += b"\r\n"
if resp.body is not None:
rbody += f"\r\n{resp.body}\r\n\r\n"
else:
rbody += "\r\n"
await super().send(rbody.encode())
rbody += resp.body
rbody += b"\r\n\r\n"
await super().send(rbody)
else:
raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol(