Merge pull request #883 from miguelgrinberg/websocket-subprotocols
Weboscket subprotocol negotiation
This commit is contained in:
commit
c797c3f22d
|
@ -214,9 +214,12 @@ class Sanic:
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
# Decorator
|
# Decorator
|
||||||
def websocket(self, uri, host=None, strict_slashes=False):
|
def websocket(self, uri, host=None, strict_slashes=False,
|
||||||
|
subprotocols=None):
|
||||||
"""Decorate a function to be registered as a websocket route
|
"""Decorate a function to be registered as a websocket route
|
||||||
:param uri: path of the URL
|
:param uri: path of the URL
|
||||||
|
:param subprotocols: optional list of strings with the supported
|
||||||
|
subprotocols
|
||||||
:param host:
|
:param host:
|
||||||
:return: decorated function
|
:return: decorated function
|
||||||
"""
|
"""
|
||||||
|
@ -236,7 +239,7 @@ class Sanic:
|
||||||
# On Python3.5 the Transport classes in asyncio do not
|
# On Python3.5 the Transport classes in asyncio do not
|
||||||
# have a get_protocol() method as in uvloop
|
# have a get_protocol() method as in uvloop
|
||||||
protocol = request.transport._protocol
|
protocol = request.transport._protocol
|
||||||
ws = await protocol.websocket_handshake(request)
|
ws = await protocol.websocket_handshake(request, subprotocols)
|
||||||
|
|
||||||
# schedule the application handler
|
# schedule the application handler
|
||||||
# its future is kept in self.websocket_tasks in case it
|
# its future is kept in self.websocket_tasks in case it
|
||||||
|
|
|
@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
else:
|
else:
|
||||||
super().write_response(response)
|
super().write_response(response)
|
||||||
|
|
||||||
async def websocket_handshake(self, request):
|
async def websocket_handshake(self, request, subprotocols=None):
|
||||||
# let the websockets package do the handshake with the client
|
# let the websockets package do the handshake with the client
|
||||||
headers = []
|
headers = []
|
||||||
|
|
||||||
|
@ -57,6 +57,17 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
except InvalidHandshake:
|
except InvalidHandshake:
|
||||||
raise InvalidUsage('Invalid websocket request')
|
raise InvalidUsage('Invalid websocket request')
|
||||||
|
|
||||||
|
subprotocol = None
|
||||||
|
if subprotocols and 'Sec-Websocket-Protocol' in request.headers:
|
||||||
|
# select a subprotocol
|
||||||
|
client_subprotocols = [p.strip() for p in request.headers[
|
||||||
|
'Sec-Websocket-Protocol'].split(',')]
|
||||||
|
for p in client_subprotocols:
|
||||||
|
if p in subprotocols:
|
||||||
|
subprotocol = p
|
||||||
|
set_header('Sec-Websocket-Protocol', subprotocol)
|
||||||
|
break
|
||||||
|
|
||||||
# write the 101 response back to the client
|
# write the 101 response back to the client
|
||||||
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
|
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
|
||||||
for k, v in headers:
|
for k, v in headers:
|
||||||
|
@ -69,5 +80,6 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
max_size=self.websocket_max_size,
|
max_size=self.websocket_max_size,
|
||||||
max_queue=self.websocket_max_queue
|
max_queue=self.websocket_max_queue
|
||||||
)
|
)
|
||||||
|
self.websocket.subprotocol = subprotocol
|
||||||
self.websocket.connection_made(request.transport)
|
self.websocket.connection_made(request.transport)
|
||||||
return self.websocket
|
return self.websocket
|
||||||
|
|
|
@ -341,6 +341,7 @@ def test_websocket_route():
|
||||||
|
|
||||||
@app.websocket('/ws')
|
@app.websocket('/ws')
|
||||||
async def handler(request, ws):
|
async def handler(request, ws):
|
||||||
|
assert ws.subprotocol is None
|
||||||
ev.set()
|
ev.set()
|
||||||
|
|
||||||
request, response = app.test_client.get('/ws', headers={
|
request, response = app.test_client.get('/ws', headers={
|
||||||
|
@ -352,6 +353,48 @@ def test_websocket_route():
|
||||||
assert ev.is_set()
|
assert ev.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_route_with_subprotocols():
|
||||||
|
app = Sanic('test_websocket_route')
|
||||||
|
results = []
|
||||||
|
|
||||||
|
@app.websocket('/ws', subprotocols=['foo', 'bar'])
|
||||||
|
async def handler(request, ws):
|
||||||
|
results.append(ws.subprotocol)
|
||||||
|
|
||||||
|
request, response = app.test_client.get('/ws', headers={
|
||||||
|
'Upgrade': 'websocket',
|
||||||
|
'Connection': 'upgrade',
|
||||||
|
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Protocol': 'bar'})
|
||||||
|
assert response.status == 101
|
||||||
|
|
||||||
|
request, response = app.test_client.get('/ws', headers={
|
||||||
|
'Upgrade': 'websocket',
|
||||||
|
'Connection': 'upgrade',
|
||||||
|
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Protocol': 'bar, foo'})
|
||||||
|
assert response.status == 101
|
||||||
|
|
||||||
|
request, response = app.test_client.get('/ws', headers={
|
||||||
|
'Upgrade': 'websocket',
|
||||||
|
'Connection': 'upgrade',
|
||||||
|
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Protocol': 'baz'})
|
||||||
|
assert response.status == 101
|
||||||
|
|
||||||
|
request, response = app.test_client.get('/ws', headers={
|
||||||
|
'Upgrade': 'websocket',
|
||||||
|
'Connection': 'upgrade',
|
||||||
|
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
|
||||||
|
'Sec-WebSocket-Version': '13'})
|
||||||
|
assert response.status == 101
|
||||||
|
|
||||||
|
assert results == ['bar', 'bar', None, None]
|
||||||
|
|
||||||
|
|
||||||
def test_route_duplicate():
|
def test_route_duplicate():
|
||||||
app = Sanic('test_route_duplicate')
|
app = Sanic('test_route_duplicate')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user