sanic/sanic/websocket.py

171 lines
5.1 KiB
Python
Raw Normal View History

from typing import (
Any,
Awaitable,
Callable,
Dict,
MutableMapping,
Optional,
Union,
)
from httptools import HttpParserUpgrade # type: ignore
from websockets import ( # type: ignore
ConnectionClosed,
InvalidHandshake,
WebSocketCommonProtocol,
handshake,
)
2018-10-18 05:20:16 +01:00
from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol
# Names that make up this module's public API.
__all__ = [
    "ConnectionClosed",
    "WebSocketProtocol",
    "WebSocketConnection",
]

# Alias for the message mappings passed to / returned by the ``send`` and
# ``receive`` callables that WebSocketConnection wraps.
ASIMessage = MutableMapping[str, Any]
class WebSocketProtocol(HttpProtocol):
    """HTTP protocol that can hand a connection over to a websocket.

    ``self.websocket`` is ``None`` until ``websocket_handshake`` has run.
    While it is ``None`` the overridden callbacks simply defer to
    ``HttpProtocol``; once it is set, incoming data is forwarded to the
    ``WebSocketCommonProtocol`` instance stored there.
    """

    def __init__(
        self,
        *args,
        websocket_timeout=10,
        websocket_max_size=None,
        websocket_max_queue=None,
        websocket_read_limit=2 ** 16,
        websocket_write_limit=2 ** 16,
        **kwargs
    ):
        """Initialise the HTTP protocol and remember the websocket settings.

        The ``websocket_*`` keyword arguments are stored unchanged and later
        passed to ``WebSocketCommonProtocol`` as ``timeout``, ``max_size``,
        ``max_queue``, ``read_limit`` and ``write_limit``.  All remaining
        positional and keyword arguments go to ``HttpProtocol.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Set by websocket_handshake(); None means "still plain HTTP".
        self.websocket = None
        self.websocket_timeout = websocket_timeout
        self.websocket_max_size = websocket_max_size
        self.websocket_max_queue = websocket_max_queue
        self.websocket_read_limit = websocket_read_limit
        self.websocket_write_limit = websocket_write_limit

    # timeouts make no sense for websocket routes, so each timeout callback
    # is a no-op once a websocket has been attached.
    def request_timeout_callback(self):
        """Run the inherited request-timeout handling for HTTP only."""
        if self.websocket is not None:
            return
        super().request_timeout_callback()

    def response_timeout_callback(self):
        """Run the inherited response-timeout handling for HTTP only."""
        if self.websocket is not None:
            return
        super().response_timeout_callback()

    def keep_alive_timeout_callback(self):
        """Run the inherited keep-alive-timeout handling for HTTP only."""
        if self.websocket is not None:
            return
        super().keep_alive_timeout_callback()

    def connection_lost(self, exc):
        """Notify the websocket (if any), then the HTTP protocol, of the loss."""
        active = self.websocket
        if active is not None:
            active.connection_lost(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        """Route incoming bytes to the websocket or to the HTTP parser."""
        active = self.websocket
        if active is not None:
            # pass the data to the websocket protocol
            active.data_received(data)
            return
        try:
            super().data_received(data)
        except HttpParserUpgrade:
            # this is okay, it just indicates we've got an upgrade request
            pass

    def write_response(self, response):
        """Write an HTTP response, or close the transport for websockets."""
        if self.websocket is None:
            super().write_response(response)
            return
        # websocket requests do not write a response
        self.transport.close()

    async def websocket_handshake(self, request, subprotocols=None):
        """Perform the websocket upgrade for ``request``.

        :param request: request object providing ``headers`` and
            ``transport``.
        :param subprotocols: optional container of subprotocol names this
            endpoint accepts; the first name offered by the client that is
            contained in it is selected.
        :return: the ``WebSocketCommonProtocol`` now attached to this
            connection (also stored in ``self.websocket``).
        :raises InvalidUsage: if the ``websockets`` handshake helpers raise
            ``InvalidHandshake``.
        """
        # let the websockets package do the handshake with the client
        response_headers = {}
        try:
            accept_key = handshake.check_request(request.headers)
            handshake.build_response(response_headers, accept_key)
        except InvalidHandshake:
            raise InvalidUsage("Invalid websocket request")

        # select a subprotocol: first client-offered name we also support
        subprotocol = None
        if subprotocols and "Sec-Websocket-Protocol" in request.headers:
            offered = request.headers["Sec-Websocket-Protocol"].split(",")
            subprotocol = next(
                (
                    name
                    for name in (candidate.strip() for candidate in offered)
                    if name in subprotocols
                ),
                None,
            )
            if subprotocol is not None:
                response_headers["Sec-Websocket-Protocol"] = subprotocol

        # write the 101 response back to the client
        chunks = [b"HTTP/1.1 101 Switching Protocols\r\n"]
        chunks.extend(
            name.encode("utf-8") + b": " + value.encode("utf-8") + b"\r\n"
            for name, value in response_headers.items()
        )
        chunks.append(b"\r\n")
        request.transport.write(b"".join(chunks))

        # hook up the websocket protocol
        ws = self.websocket = WebSocketCommonProtocol(
            timeout=self.websocket_timeout,
            max_size=self.websocket_max_size,
            max_queue=self.websocket_max_queue,
            read_limit=self.websocket_read_limit,
            write_limit=self.websocket_write_limit,
        )
        # Following two lines are required for websockets 8.x
        ws.is_client = False
        ws.side = "server"
        ws.subprotocol = subprotocol
        ws.connection_made(request.transport)
        ws.connection_open()
        return ws
2019-05-21 23:42:19 +01:00
class WebSocketConnection:
    """Websocket-style wrapper around a pair of ``send``/``receive`` callables.

    Outgoing data is wrapped into ``websocket.send`` / ``websocket.accept``
    messages and handed to ``send``; incoming messages are awaited from
    ``receive`` and unwrapped by ``recv``.
    """

    # TODO
    # - Implement ping/pong

    def __init__(
        self,
        send: Callable[[ASIMessage], Awaitable[None]],
        receive: Callable[[], Awaitable[ASIMessage]],
    ) -> None:
        """Store the two callables used to exchange messages.

        :param send: awaitable callable taking one message mapping.
        :param receive: awaitable callable returning one message mapping.
        """
        self._send = send
        self._receive = receive

    async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
        """Send ``data`` as a ``websocket.send`` message.

        ``bytes`` are sent under the ``"bytes"`` key; anything else is
        converted with ``str()`` and sent under the ``"text"`` key.  Extra
        positional and keyword arguments are accepted and ignored.
        """
        message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
        if isinstance(data, bytes):
            message["bytes"] = data
        else:
            message["text"] = str(data)
        await self._send(message)

    async def recv(self, *args, **kwargs) -> Optional[Union[str, bytes]]:
        """Await the next message and return its payload.

        :return: the ``"text"`` payload of a ``websocket.receive`` message
            when it is present and not ``None``; otherwise its ``"bytes"``
            payload (binary frame).  ``None`` for any other message type
            (e.g. ``websocket.disconnect``) or when neither payload is set.

        Extra positional and keyword arguments are accepted and ignored.
        """
        message = await self._receive()
        if message["type"] != "websocket.receive":
            # "websocket.disconnect" and unknown types carry no payload.
            return None

        # A binary frame has no usable "text" entry (missing or None), so
        # indexing message["text"] directly would raise KeyError or drop
        # the frame; fall back to the "bytes" payload instead.
        text = message.get("text")
        if text is not None:
            return text
        return message.get("bytes")

    # Alias kept so callers can use either name.
    receive = recv

    async def accept(self) -> None:
        """Send a ``websocket.accept`` message with an empty subprotocol."""
        await self._send({"type": "websocket.accept", "subprotocol": ""})

    async def close(self) -> None:
        """Do nothing; closing is not implemented by this wrapper."""
        pass