Merge pull request #469 from miguelgrinberg/websocket-support

websocket support
This commit is contained in:
Eli Uriegas 2017-03-05 17:42:49 -08:00 committed by GitHub
commit 8e6678d526
11 changed files with 280 additions and 2 deletions

View File

@ -55,6 +55,11 @@ will look like:
Blueprints have much the same functionality as an application instance. Blueprints have much the same functionality as an application instance.
### WebSocket routes
WebSocket handlers can be registered on a blueprint using the `@bp.route`
decorator or `bp.add_websocket_route` method.
### Middleware ### Middleware
Using blueprints allows you to also register middleware globally. Using blueprints allows you to also register middleware globally.

View File

@ -181,3 +181,37 @@ url = app.url_for('post_handler', post_id=5, arg_one=['one', 'two'], arg_two=2,
# http://another_server:8888/posts/5?arg_one=one&arg_one=two&arg_two=2#anchor # http://another_server:8888/posts/5?arg_one=one&arg_one=two&arg_two=2#anchor
``` ```
- All valid parameters must be passed to `url_for` to build a URL. If a parameter is not supplied, or if a parameter does not match the specified type, a `URLBuildError` will be thrown. - All valid parameters must be passed to `url_for` to build a URL. If a parameter is not supplied, or if a parameter does not match the specified type, a `URLBuildError` will be thrown.
## WebSocket routes
Routes for the WebSocket protocol can be defined with the `@app.websocket`
decorator:
```python
@app.websocket('/feed')
async def feed(request, ws):
while True:
data = 'hello!'
print('Sending: ' + data)
await ws.send(data)
data = await ws.recv()
print('Received: ' + data)
```
Alternatively, the `app.add_websocket_route` method can be used instead of the
decorator:
```python
async def feed(request, ws):
pass
app.add_websocket_route(my_websocket_handler, '/feed')
```
Handlers for a WebSocket route are passed the request as first argument, and a
WebSocket protocol object as second argument. The protocol object has `send`
and `recv` methods to send and receive data respectively.
WebSocket support requires the [websockets](https://github.com/aaugustin/websockets)
package by Aymeric Augustin.

29
examples/websocket.html Normal file
View File

@ -0,0 +1,29 @@
<!DOCTYPE html>
<html>
<head>
<title>WebSocket demo</title>
</head>
<body>
<script>
var ws = new WebSocket('ws://' + document.domain + ':' + location.port + '/feed'),
messages = document.createElement('ul');
ws.onmessage = function (event) {
var messages = document.getElementsByTagName('ul')[0],
message = document.createElement('li'),
content = document.createTextNode('Received: ' + event.data);
message.appendChild(content);
messages.appendChild(message);
};
document.body.appendChild(messages);
window.setInterval(function() {
data = 'bye!'
ws.send(data);
var messages = document.getElementsByTagName('ul')[0],
message = document.createElement('li'),
content = document.createTextNode('Sent: ' + data);
message.appendChild(content);
messages.appendChild(message);
}, 1000);
</script>
</body>
</html>

23
examples/websocket.py Normal file
View File

@ -0,0 +1,23 @@
from sanic import Sanic
from sanic.response import file
app = Sanic(__name__)
@app.route('/')
async def index(request):
return await file('websocket.html')
@app.websocket('/feed')
async def feed(request, ws):
while True:
data = 'hello!'
print('Sending: ' + data)
await ws.send(data)
data = await ws.recv()
print('Received: ' + data)
if __name__ == '__main__':
app.run()

View File

@ -19,6 +19,7 @@ from sanic.server import serve, serve_multiple, HttpProtocol
from sanic.static import register as static_register from sanic.static import register as static_register
from sanic.testing import TestClient from sanic.testing import TestClient
from sanic.views import CompositionView from sanic.views import CompositionView
from sanic.websocket import WebSocketProtocol, ConnectionClosed
class Sanic: class Sanic:
@ -51,6 +52,7 @@ class Sanic:
self.sock = None self.sock = None
self.listeners = defaultdict(list) self.listeners = defaultdict(list)
self.is_running = False self.is_running = False
self.websocket_enabled = False
# Register alternative method names # Register alternative method names
self.go_fast = self.run self.go_fast = self.run
@ -168,6 +170,50 @@ class Sanic:
self.route(uri=uri, methods=methods, host=host)(handler) self.route(uri=uri, methods=methods, host=host)(handler)
return handler return handler
# Decorator
def websocket(self, uri, host=None):
"""Decorate a function to be registered as a websocket route
:param uri: path of the URL
:param host:
:return: decorated function
"""
self.websocket_enabled = True
# Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working
if not uri.startswith('/'):
uri = '/' + uri
def response(handler):
async def websocket_handler(request, *args, **kwargs):
request.app = self
protocol = request.transport.get_protocol()
ws = await protocol.websocket_handshake(request)
try:
# invoke the application handler
await handler(request, ws, *args, **kwargs)
except ConnectionClosed:
pass
await ws.close()
self.router.add(uri=uri, handler=websocket_handler,
methods=frozenset({'GET'}), host=host)
return handler
return response
def add_websocket_route(self, handler, uri, host=None):
"""A helper method to register a function as a websocket route."""
return self.websocket(uri, host=host)(handler)
def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket.
Websocket is enabled automatically if websocket routes are
added to the application.
"""
self.websocket_enabled = enable
def remove_route(self, uri, clean_cache=True, host=None): def remove_route(self, uri, clean_cache=True, host=None):
self.router.remove(uri, clean_cache, host) self.router.remove(uri, clean_cache, host)
@ -437,7 +483,7 @@ class Sanic:
def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None, def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None,
after_start=None, before_stop=None, after_stop=None, ssl=None, after_start=None, before_stop=None, after_stop=None, ssl=None,
sock=None, workers=1, loop=None, protocol=HttpProtocol, sock=None, workers=1, loop=None, protocol=None,
backlog=100, stop_event=None, register_sys_signals=True): backlog=100, stop_event=None, register_sys_signals=True):
"""Run the HTTP Server and listen until keyboard interrupt or term """Run the HTTP Server and listen until keyboard interrupt or term
signal. On termination, drain connections before closing. signal. On termination, drain connections before closing.
@ -464,6 +510,9 @@ class Sanic:
:param protocol: Subclass of asyncio protocol class :param protocol: Subclass of asyncio protocol class
:return: Nothing :return: Nothing
""" """
if protocol is None:
protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol)
server_settings = self._helper( server_settings = self._helper(
host=host, port=port, debug=debug, before_start=before_start, host=host, port=port, debug=debug, before_start=before_start,
after_start=after_start, before_stop=before_stop, after_start=after_start, before_stop=before_stop,
@ -491,13 +540,16 @@ class Sanic:
async def create_server(self, host="127.0.0.1", port=8000, debug=False, async def create_server(self, host="127.0.0.1", port=8000, debug=False,
before_start=None, after_start=None, before_start=None, after_start=None,
before_stop=None, after_stop=None, ssl=None, before_stop=None, after_stop=None, ssl=None,
sock=None, loop=None, protocol=HttpProtocol, sock=None, loop=None, protocol=None,
backlog=100, stop_event=None): backlog=100, stop_event=None):
"""Asynchronous version of `run`. """Asynchronous version of `run`.
NOTE: This does not support multiprocessing and is not the preferred NOTE: This does not support multiprocessing and is not the preferred
way to run a Sanic application. way to run a Sanic application.
""" """
if protocol is None:
protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol)
server_settings = self._helper( server_settings = self._helper(
host=host, port=port, debug=debug, before_start=before_start, host=host, port=port, debug=debug, before_start=before_start,
after_start=after_start, before_stop=before_stop, after_start=after_start, before_stop=before_stop,

View File

@ -23,6 +23,7 @@ class Blueprint:
self.host = host self.host = host
self.routes = [] self.routes = []
self.websocket_routes = []
self.exceptions = [] self.exceptions = []
self.listeners = defaultdict(list) self.listeners = defaultdict(list)
self.middlewares = [] self.middlewares = []
@ -46,6 +47,17 @@ class Blueprint:
host=future.host or self.host host=future.host or self.host
)(future.handler) )(future.handler)
for future in self.websocket_routes:
# attach the blueprint name to the handler so that it can be
# prefixed properly in the router
future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
app.websocket(
uri=uri,
host=future.host or self.host
)(future.handler)
# Middleware # Middleware
for future in self.middlewares: for future in self.middlewares:
if future.args or future.kwargs: if future.args or future.kwargs:
@ -106,6 +118,28 @@ class Blueprint:
self.route(uri=uri, methods=methods, host=host)(handler) self.route(uri=uri, methods=methods, host=host)(handler)
return handler return handler
def websocket(self, uri, host=None):
"""Create a blueprint websocket route from a decorated function.
:param uri: endpoint at which the route will be accessible.
"""
def decorator(handler):
route = FutureRoute(handler, uri, [], host)
self.websocket_routes.append(route)
return handler
return decorator
def add_websocket_route(self, handler, uri, host=None):
"""Create a blueprint websocket route from a function.
:param handler: function for handling uri requests. Accepts function,
or class instance with a view_class method.
:param uri: endpoint at which the route will be accessible.
:return: function or class instance
"""
self.websocket(uri=uri, host=host)(handler)
return handler
def listener(self, event): def listener(self, event):
"""Create a listener from a decorated function. """Create a listener from a decorated function.

67
sanic/websocket.py Normal file
View File

@ -0,0 +1,67 @@
from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol
from httptools import HttpParserUpgrade
from websockets import handshake, WebSocketCommonProtocol, InvalidHandshake
from websockets import ConnectionClosed # noqa
class WebSocketProtocol(HttpProtocol):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.websocket = None
def connection_timeout(self):
# timeouts make no sense for websocket routes
if self.websocket is None:
super().connection_timeout()
def connection_lost(self, exc):
if self.websocket is not None:
self.websocket.connection_lost(exc)
super().connection_lost(exc)
def data_received(self, data):
if self.websocket is not None:
# pass the data to the websocket protocol
self.websocket.data_received(data)
else:
try:
super().data_received(data)
except HttpParserUpgrade:
# this is okay, it just indicates we've got an upgrade request
pass
def write_response(self, response):
if self.websocket is not None:
# websocket requests do not write a response
self.transport.close()
else:
super().write_response(response)
async def websocket_handshake(self, request):
# let the websockets package do the handshake with the client
headers = []
def get_header(k):
return request.headers.get(k, '')
def set_header(k, v):
headers.append((k, v))
try:
key = handshake.check_request(get_header)
handshake.build_response(set_header, key)
except InvalidHandshake:
raise InvalidUsage('Invalid websocket request')
# write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
for k, v in headers:
rv += k.encode('utf-8') + b': ' + v.encode('utf-8') + b'\r\n'
rv += b'\r\n'
request.transport.write(rv)
# hook up the websocket protocol
self.websocket = WebSocketCommonProtocol()
self.websocket.connection_made(request.transport)
return self.websocket

View File

@ -19,6 +19,7 @@ install_requires = [
'httptools>=0.0.9', 'httptools>=0.0.9',
'ujson>=1.35', 'ujson>=1.35',
'aiofiles>=0.3.0', 'aiofiles>=0.3.0',
'websockets>=3.2',
] ]
if os.name != 'nt': if os.name != 'nt':

View File

@ -1,3 +1,4 @@
import asyncio
import inspect import inspect
from sanic import Sanic from sanic import Sanic
@ -236,6 +237,7 @@ def test_bp_static():
def test_bp_shorthand(): def test_bp_shorthand():
app = Sanic('test_shorhand_routes') app = Sanic('test_shorhand_routes')
blueprint = Blueprint('test_shorhand_routes') blueprint = Blueprint('test_shorhand_routes')
ev = asyncio.Event()
@blueprint.get('/get') @blueprint.get('/get')
def handler(request): def handler(request):
@ -265,6 +267,10 @@ def test_bp_shorthand():
def handler(request): def handler(request):
return text('OK') return text('OK')
@blueprint.websocket('/ws')
async def handler(request, ws):
ev.set()
app.blueprint(blueprint) app.blueprint(blueprint)
request, response = app.test_client.get('/get') request, response = app.test_client.get('/get')
@ -308,3 +314,11 @@ def test_bp_shorthand():
request, response = app.test_client.get('/delete') request, response = app.test_client.get('/delete')
assert response.status == 405 assert response.status == 405
request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13'})
assert response.status == 101
assert ev.is_set()

View File

@ -1,3 +1,4 @@
import asyncio
import pytest import pytest
from sanic import Sanic from sanic import Sanic
@ -234,6 +235,23 @@ def test_dynamic_route_unhashable():
assert response.status == 404 assert response.status == 404
def test_websocket_route():
app = Sanic('test_websocket_route')
ev = asyncio.Event()
@app.websocket('/ws')
async def handler(request, ws):
ev.set()
request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13'})
assert response.status == 101
assert ev.is_set()
def test_route_duplicate(): def test_route_duplicate():
app = Sanic('test_route_duplicate') app = Sanic('test_route_duplicate')

View File

@ -12,6 +12,7 @@ python =
deps = deps =
aiofiles aiofiles
aiohttp aiohttp
websockets
pytest pytest
beautifulsoup4 beautifulsoup4
coverage coverage