| @@ -214,9 +214,12 @@ class Sanic: | |||||||
|         return handler |         return handler | ||||||
|  |  | ||||||
|     # Decorator |     # Decorator | ||||||
|     def websocket(self, uri, host=None, strict_slashes=False): |     def websocket(self, uri, host=None, strict_slashes=False, | ||||||
|  |                   subprotocols=None): | ||||||
|         """Decorate a function to be registered as a websocket route |         """Decorate a function to be registered as a websocket route | ||||||
|         :param uri: path of the URL |         :param uri: path of the URL | ||||||
|  |         :param subprotocols: optional list of strings with the supported | ||||||
|  |                              subprotocols | ||||||
|         :param host: |         :param host: | ||||||
|         :return: decorated function |         :return: decorated function | ||||||
|         """ |         """ | ||||||
| @@ -236,7 +239,7 @@ class Sanic: | |||||||
|                     # On Python3.5 the Transport classes in asyncio do not |                     # On Python3.5 the Transport classes in asyncio do not | ||||||
|                     # have a get_protocol() method as in uvloop |                     # have a get_protocol() method as in uvloop | ||||||
|                     protocol = request.transport._protocol |                     protocol = request.transport._protocol | ||||||
|                 ws = await protocol.websocket_handshake(request) |                 ws = await protocol.websocket_handshake(request, subprotocols) | ||||||
|  |  | ||||||
|                 # schedule the application handler |                 # schedule the application handler | ||||||
|                 # its future is kept in self.websocket_tasks in case it |                 # its future is kept in self.websocket_tasks in case it | ||||||
|   | |||||||
| @@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|         else: |         else: | ||||||
|             super().write_response(response) |             super().write_response(response) | ||||||
|  |  | ||||||
|     async def websocket_handshake(self, request): |     async def websocket_handshake(self, request, subprotocols=None): | ||||||
|         # let the websockets package do the handshake with the client |         # let the websockets package do the handshake with the client | ||||||
|         headers = [] |         headers = [] | ||||||
|  |  | ||||||
| @@ -57,6 +57,17 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|         except InvalidHandshake: |         except InvalidHandshake: | ||||||
|             raise InvalidUsage('Invalid websocket request') |             raise InvalidUsage('Invalid websocket request') | ||||||
|  |  | ||||||
|  |         subprotocol = None | ||||||
|  |         if subprotocols and 'Sec-Websocket-Protocol' in request.headers: | ||||||
|  |             # select a subprotocol | ||||||
|  |             client_subprotocols = [p.strip() for p in request.headers[ | ||||||
|  |                 'Sec-Websocket-Protocol'].split(',')] | ||||||
|  |             for p in client_subprotocols: | ||||||
|  |                 if p in subprotocols: | ||||||
|  |                     subprotocol = p | ||||||
|  |                     set_header('Sec-Websocket-Protocol', subprotocol) | ||||||
|  |                     break | ||||||
|  |  | ||||||
|         # write the 101 response back to the client |         # write the 101 response back to the client | ||||||
|         rv = b'HTTP/1.1 101 Switching Protocols\r\n' |         rv = b'HTTP/1.1 101 Switching Protocols\r\n' | ||||||
|         for k, v in headers: |         for k, v in headers: | ||||||
| @@ -69,5 +80,6 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|             max_size=self.websocket_max_size, |             max_size=self.websocket_max_size, | ||||||
|             max_queue=self.websocket_max_queue |             max_queue=self.websocket_max_queue | ||||||
|         ) |         ) | ||||||
|  |         self.websocket.subprotocol = subprotocol | ||||||
|         self.websocket.connection_made(request.transport) |         self.websocket.connection_made(request.transport) | ||||||
|         return self.websocket |         return self.websocket | ||||||
|   | |||||||
| @@ -341,6 +341,7 @@ def test_websocket_route(): | |||||||
|  |  | ||||||
|     @app.websocket('/ws') |     @app.websocket('/ws') | ||||||
|     async def handler(request, ws): |     async def handler(request, ws): | ||||||
|  |         assert ws.subprotocol is None | ||||||
|         ev.set() |         ev.set() | ||||||
|  |  | ||||||
|     request, response = app.test_client.get('/ws', headers={ |     request, response = app.test_client.get('/ws', headers={ | ||||||
| @@ -352,6 +353,48 @@ def test_websocket_route(): | |||||||
|     assert ev.is_set() |     assert ev.is_set() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_websocket_route_with_subprotocols(): | ||||||
|  |     app = Sanic('test_websocket_route') | ||||||
|  |     results = [] | ||||||
|  |  | ||||||
|  |     @app.websocket('/ws', subprotocols=['foo', 'bar']) | ||||||
|  |     async def handler(request, ws): | ||||||
|  |         results.append(ws.subprotocol) | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get('/ws', headers={ | ||||||
|  |         'Upgrade': 'websocket', | ||||||
|  |         'Connection': 'upgrade', | ||||||
|  |         'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', | ||||||
|  |         'Sec-WebSocket-Version': '13', | ||||||
|  |         'Sec-WebSocket-Protocol': 'bar'}) | ||||||
|  |     assert response.status == 101 | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get('/ws', headers={ | ||||||
|  |         'Upgrade': 'websocket', | ||||||
|  |         'Connection': 'upgrade', | ||||||
|  |         'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', | ||||||
|  |         'Sec-WebSocket-Version': '13', | ||||||
|  |         'Sec-WebSocket-Protocol': 'bar, foo'}) | ||||||
|  |     assert response.status == 101 | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get('/ws', headers={ | ||||||
|  |         'Upgrade': 'websocket', | ||||||
|  |         'Connection': 'upgrade', | ||||||
|  |         'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', | ||||||
|  |         'Sec-WebSocket-Version': '13', | ||||||
|  |         'Sec-WebSocket-Protocol': 'baz'}) | ||||||
|  |     assert response.status == 101 | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get('/ws', headers={ | ||||||
|  |         'Upgrade': 'websocket', | ||||||
|  |         'Connection': 'upgrade', | ||||||
|  |         'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', | ||||||
|  |         'Sec-WebSocket-Version': '13'}) | ||||||
|  |     assert response.status == 101 | ||||||
|  |      | ||||||
|  |     assert results == ['bar', 'bar', None, None] | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_route_duplicate(): | def test_route_duplicate(): | ||||||
|     app = Sanic('test_route_duplicate') |     app = Sanic('test_route_duplicate') | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Miguel Grinberg
					Miguel Grinberg