diff --git a/sanic/app.py b/sanic/app.py index f0ccad86..a3acbecf 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -214,9 +214,12 @@ class Sanic: return handler # Decorator - def websocket(self, uri, host=None, strict_slashes=False): + def websocket(self, uri, host=None, strict_slashes=False, + subprotocols=None): """Decorate a function to be registered as a websocket route :param uri: path of the URL + :param subprotocols: optional list of strings with the supported + subprotocols :param host: :return: decorated function """ @@ -236,7 +239,7 @@ class Sanic: # On Python3.5 the Transport classes in asyncio do not # have a get_protocol() method as in uvloop protocol = request.transport._protocol - ws = await protocol.websocket_handshake(request) + ws = await protocol.websocket_handshake(request, subprotocols) # schedule the application handler # its future is kept in self.websocket_tasks in case it diff --git a/sanic/websocket.py b/sanic/websocket.py index 94320a5e..e8e9922f 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol): else: super().write_response(response) - async def websocket_handshake(self, request): + async def websocket_handshake(self, request, subprotocols=None): # let the websockets package do the handshake with the client headers = [] @@ -57,6 +57,17 @@ class WebSocketProtocol(HttpProtocol): except InvalidHandshake: raise InvalidUsage('Invalid websocket request') + subprotocol = None + if subprotocols and 'Sec-Websocket-Protocol' in request.headers: + # select a subprotocol + client_subprotocols = [p.strip() for p in request.headers[ + 'Sec-Websocket-Protocol'].split(',')] + for p in client_subprotocols: + if p in subprotocols: + subprotocol = p + set_header('Sec-Websocket-Protocol', subprotocol) + break + # write the 101 response back to the client rv = b'HTTP/1.1 101 Switching Protocols\r\n' for k, v in headers: @@ -69,5 +80,6 @@ class WebSocketProtocol(HttpProtocol): max_size=self.websocket_max_size, max_queue=self.websocket_max_queue ) + self.websocket.subprotocol = subprotocol self.websocket.connection_made(request.transport) return self.websocket diff --git a/tests/test_routes.py b/tests/test_routes.py index 04a682a0..b356c2d5 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -341,6 +341,7 @@ def test_websocket_route(): @app.websocket('/ws') async def handler(request, ws): + assert ws.subprotocol is None ev.set() request, response = app.test_client.get('/ws', headers={ @@ -352,6 +353,48 @@ def test_websocket_route(): assert ev.is_set() +def test_websocket_route_with_subprotocols(): + app = Sanic('test_websocket_route') + results = [] + + @app.websocket('/ws', subprotocols=['foo', 'bar']) + async def handler(request, ws): + results.append(ws.subprotocol) + + request, response = app.test_client.get('/ws', headers={ + 'Upgrade': 'websocket', + 'Connection': 'upgrade', + 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Protocol': 'bar'}) + assert response.status == 101 + + request, response = app.test_client.get('/ws', headers={ + 'Upgrade': 'websocket', + 'Connection': 'upgrade', + 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Protocol': 'bar, foo'}) + assert response.status == 101 + + request, response = app.test_client.get('/ws', headers={ + 'Upgrade': 'websocket', + 'Connection': 'upgrade', + 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Protocol': 'baz'}) + assert response.status == 101 + + request, response = app.test_client.get('/ws', headers={ + 'Upgrade': 'websocket', + 'Connection': 'upgrade', + 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version': '13'}) + assert response.status == 101 + + assert results == ['bar', 'bar', None, None] + + def test_route_duplicate(): app = Sanic('test_route_duplicate')