Weboscket subprotocol negotiation

Fixes #874
This commit is contained in:
Miguel Grinberg 2017-08-08 11:21:52 -07:00
parent 7b66a56cad
commit 375ed23216
No known key found for this signature in database
GPG Key ID: 36848B262DF5F06C
3 changed files with 61 additions and 3 deletions

View File

@ -214,9 +214,12 @@ class Sanic:
return handler return handler
# Decorator # Decorator
def websocket(self, uri, host=None, strict_slashes=False): def websocket(self, uri, host=None, strict_slashes=False,
subprotocols=None):
"""Decorate a function to be registered as a websocket route """Decorate a function to be registered as a websocket route
:param uri: path of the URL :param uri: path of the URL
:param subprotocols: optional list of strings with the supported
subprotocols
:param host: :param host:
:return: decorated function :return: decorated function
""" """
@ -236,7 +239,7 @@ class Sanic:
# On Python3.5 the Transport classes in asyncio do not # On Python3.5 the Transport classes in asyncio do not
# have a get_protocol() method as in uvloop # have a get_protocol() method as in uvloop
protocol = request.transport._protocol protocol = request.transport._protocol
ws = await protocol.websocket_handshake(request) ws = await protocol.websocket_handshake(request, subprotocols)
# schedule the application handler # schedule the application handler
# its future is kept in self.websocket_tasks in case it # its future is kept in self.websocket_tasks in case it

View File

@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol):
else: else:
super().write_response(response) super().write_response(response)
async def websocket_handshake(self, request): async def websocket_handshake(self, request, subprotocols=None):
# let the websockets package do the handshake with the client # let the websockets package do the handshake with the client
headers = [] headers = []
@ -57,6 +57,17 @@ class WebSocketProtocol(HttpProtocol):
except InvalidHandshake: except InvalidHandshake:
raise InvalidUsage('Invalid websocket request') raise InvalidUsage('Invalid websocket request')
subprotocol = None
if subprotocols and 'Sec-Websocket-Protocol' in request.headers:
# select a subprotocol
client_subprotocols = [p.strip() for p in request.headers[
'Sec-Websocket-Protocol'].split(',')]
for p in client_subprotocols:
if p in subprotocols:
subprotocol = p
set_header('Sec-Websocket-Protocol', subprotocol)
break
# write the 101 response back to the client # write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n' rv = b'HTTP/1.1 101 Switching Protocols\r\n'
for k, v in headers: for k, v in headers:
@ -69,5 +80,6 @@ class WebSocketProtocol(HttpProtocol):
max_size=self.websocket_max_size, max_size=self.websocket_max_size,
max_queue=self.websocket_max_queue max_queue=self.websocket_max_queue
) )
self.websocket.subprotocol = subprotocol
self.websocket.connection_made(request.transport) self.websocket.connection_made(request.transport)
return self.websocket return self.websocket

View File

@ -341,6 +341,7 @@ def test_websocket_route():
@app.websocket('/ws') @app.websocket('/ws')
async def handler(request, ws): async def handler(request, ws):
assert ws.subprotocol is None
ev.set() ev.set()
request, response = app.test_client.get('/ws', headers={ request, response = app.test_client.get('/ws', headers={
@ -352,6 +353,48 @@ def test_websocket_route():
assert ev.is_set() assert ev.is_set()
def test_websocket_route_with_subprotocols():
app = Sanic('test_websocket_route')
results = []
@app.websocket('/ws', subprotocols=['foo', 'bar'])
async def handler(request, ws):
results.append(ws.subprotocol)
request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Protocol': 'bar'})
assert response.status == 101
request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Protocol': 'bar, foo'})
assert response.status == 101
request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Protocol': 'baz'})
assert response.status == 101
request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13'})
assert response.status == 101
assert results == ['bar', 'bar', None, None]
def test_route_duplicate(): def test_route_duplicate():
app = Sanic('test_route_duplicate') app = Sanic('test_route_duplicate')