2019-09-22 21:55:36 +01:00
|
|
|
from typing import (
|
|
|
|
Any,
|
|
|
|
Awaitable,
|
|
|
|
Callable,
|
|
|
|
Dict,
|
|
|
|
MutableMapping,
|
|
|
|
Optional,
|
|
|
|
Union,
|
|
|
|
)
|
|
|
|
|
|
|
|
from httptools import HttpParserUpgrade # type: ignore
|
|
|
|
from websockets import ( # type: ignore
|
|
|
|
ConnectionClosed,
|
|
|
|
InvalidHandshake,
|
|
|
|
WebSocketCommonProtocol,
|
|
|
|
handshake,
|
|
|
|
)
|
2018-10-18 05:20:16 +01:00
|
|
|
|
|
|
|
from sanic.exceptions import InvalidUsage
|
|
|
|
from sanic.server import HttpProtocol
|
2017-02-20 20:25:44 +00:00
|
|
|
|
|
|
|
|
2019-09-22 21:55:36 +01:00
|
|
|
# Public API of this module. ConnectionClosed is re-exported from the
# ``websockets`` package (imported above) alongside the two local classes.
__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"]
|
|
|
|
|
2019-05-21 23:42:19 +01:00
|
|
|
# Type alias for an ASGI event dict, e.g. {"type": "websocket.send", ...},
# as exchanged through WebSocketConnection's send/receive callables.
ASIMessage = MutableMapping[str, Any]
|
|
|
|
|
|
|
|
|
2017-02-20 20:25:44 +00:00
|
|
|
class WebSocketProtocol(HttpProtocol):
    """``HttpProtocol`` that can switch a live connection over to WebSockets.

    While ``self.websocket`` is ``None`` the connection behaves exactly like
    the parent HTTP protocol. Once :meth:`websocket_handshake` has run,
    incoming bytes and connection loss are routed to a
    ``WebSocketCommonProtocol`` instance, and the HTTP request / response /
    keep-alive timeout callbacks stop firing.
    """

    def __init__(
        self,
        *args,
        websocket_timeout=10,
        websocket_max_size=None,
        websocket_max_queue=None,
        websocket_read_limit=2 ** 16,
        websocket_write_limit=2 ** 16,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Set by websocket_handshake(); None while the connection is plain HTTP.
        self.websocket = None
        # Options handed unchanged to WebSocketCommonProtocol at handshake time.
        self.websocket_timeout = websocket_timeout
        self.websocket_max_size = websocket_max_size
        self.websocket_max_queue = websocket_max_queue
        self.websocket_read_limit = websocket_read_limit
        self.websocket_write_limit = websocket_write_limit

    # HTTP timeouts make no sense for an upgraded websocket, so each of the
    # callbacks below only defers to HttpProtocol while no websocket exists.

    def request_timeout_callback(self):
        if self.websocket is not None:
            return
        super().request_timeout_callback()

    def response_timeout_callback(self):
        if self.websocket is not None:
            return
        super().response_timeout_callback()

    def keep_alive_timeout_callback(self):
        if self.websocket is not None:
            return
        super().keep_alive_timeout_callback()

    def connection_lost(self, exc):
        """Tell the websocket (if any) first, then let HttpProtocol clean up."""
        websocket = self.websocket
        if websocket is not None:
            websocket.connection_lost(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        """Route raw bytes to the websocket, or to the HTTP parser before it."""
        if self.websocket is None:
            try:
                super().data_received(data)
            except HttpParserUpgrade:
                # The parser signals an Upgrade request this way; not an error.
                pass
            return
        self.websocket.data_received(data)

    def write_response(self, response):
        """Write an HTTP response, or just close the transport for websockets."""
        if self.websocket is None:
            super().write_response(response)
            return
        # A websocket request has no HTTP response body to send.
        self.transport.close()

    async def websocket_handshake(self, request, subprotocols=None):
        """Complete the WebSocket opening handshake for *request*.

        Validates the upgrade headers, selects the first client-offered
        subprotocol that also appears in *subprotocols*, writes the
        ``101 Switching Protocols`` response to ``request.transport`` and
        attaches a server-side ``WebSocketCommonProtocol`` to that transport.

        :param request: the upgrade request; its ``headers`` and
            ``transport`` are used.
        :param subprotocols: subprotocol names the server is willing to speak.
        :return: the new ``WebSocketCommonProtocol`` (also kept on
            ``self.websocket``).
        :raises InvalidUsage: if the handshake headers are invalid.
        """
        response_headers = {}
        try:
            accept_key = handshake.check_request(request.headers)
            handshake.build_response(response_headers, accept_key)
        except InvalidHandshake:
            raise InvalidUsage("Invalid websocket request")

        # Honour the client's preference order when picking a subprotocol.
        chosen = None
        if subprotocols and "Sec-Websocket-Protocol" in request.headers:
            offered = [
                item.strip()
                for item in request.headers["Sec-Websocket-Protocol"].split(",")
            ]
            chosen = next((p for p in offered if p in subprotocols), None)
            if chosen is not None:
                response_headers["Sec-Websocket-Protocol"] = chosen

        status_line = b"HTTP/1.1 101 Switching Protocols\r\n"
        header_block = b"".join(
            name.encode("utf-8") + b": " + value.encode("utf-8") + b"\r\n"
            for name, value in response_headers.items()
        )
        request.transport.write(status_line + header_block + b"\r\n")

        # Publish the protocol before wiring it to the transport, as before.
        self.websocket = ws = WebSocketCommonProtocol(
            timeout=self.websocket_timeout,
            max_size=self.websocket_max_size,
            max_queue=self.websocket_max_queue,
            read_limit=self.websocket_read_limit,
            write_limit=self.websocket_write_limit,
        )
        # Following two lines are required for websockets 8.x
        ws.is_client = False
        ws.side = "server"
        ws.subprotocol = chosen
        ws.connection_made(request.transport)
        ws.connection_open()
        return ws
|
2019-05-21 23:42:19 +01:00
|
|
|
|
|
|
|
|
|
|
|
class WebSocketConnection:
    """Small ``websockets``-style API on top of an ASGI send/receive pair.

    Each method translates into ASGI ``websocket.*`` event dicts that are
    passed to (or read from) the ``send`` / ``receive`` callables supplied
    at construction.
    """

    # TODO
    # - Implement ping/pong

    def __init__(
        self,
        send: "Callable[[ASIMessage], Awaitable[None]]",
        receive: "Callable[[], Awaitable[ASIMessage]]",
    ) -> None:
        self._send = send
        self._receive = receive

    async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
        """Send one message to the client.

        Binary-like data (``bytes``, ``bytearray``, ``memoryview``) is sent
        as a binary frame; anything else is converted with ``str()`` and
        sent as text. Extra arguments are accepted and ignored.
        """
        message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
        if isinstance(data, (bytes, bytearray, memoryview)):
            # ASGI requires the payload to be exactly ``bytes``; without this
            # a bytearray would have been sent as its repr() text.
            message["bytes"] = bytes(data)
        else:
            message["text"] = str(data)
        await self._send(message)

    async def recv(self, *args, **kwargs) -> Optional[Union[str, bytes]]:
        """Wait for the next client message.

        :return: the text payload of a text frame, the ``bytes`` payload of
            a binary frame, or ``None`` on ``websocket.disconnect`` (and on
            any other event type). Extra arguments are ignored.
        """
        message = await self._receive()
        if message["type"] == "websocket.receive":
            # ASGI sets exactly one of "text"/"bytes"; the other may be
            # None or missing entirely, so never index it directly.
            text = message.get("text")
            if text is not None:
                return text
            return message.get("bytes")
        return None

    receive = recv

    async def accept(self, subprotocol: Optional[str] = None) -> None:
        """Accept the connection, optionally selecting *subprotocol*.

        ``None`` (the ASGI default) means no subprotocol is selected. An
        empty string is not a valid choice: servers may echo it back or
        reject it as a subprotocol the client never offered.
        """
        await self._send(
            {"type": "websocket.accept", "subprotocol": subprotocol}
        )

    async def close(self) -> None:
        """Currently a no-op: no ``websocket.close`` event is sent."""
        pass
|