Merge pull request #469 from miguelgrinberg/websocket-support
websocket support
This commit is contained in:
56
sanic/app.py
56
sanic/app.py
@@ -19,6 +19,7 @@ from sanic.server import serve, serve_multiple, HttpProtocol
|
||||
from sanic.static import register as static_register
|
||||
from sanic.testing import TestClient
|
||||
from sanic.views import CompositionView
|
||||
from sanic.websocket import WebSocketProtocol, ConnectionClosed
|
||||
|
||||
|
||||
class Sanic:
|
||||
@@ -51,6 +52,7 @@ class Sanic:
|
||||
self.sock = None
|
||||
self.listeners = defaultdict(list)
|
||||
self.is_running = False
|
||||
self.websocket_enabled = False
|
||||
|
||||
# Register alternative method names
|
||||
self.go_fast = self.run
|
||||
@@ -168,6 +170,50 @@ class Sanic:
|
||||
self.route(uri=uri, methods=methods, host=host)(handler)
|
||||
return handler
|
||||
|
||||
# Decorator
|
||||
def websocket(self, uri, host=None):
|
||||
"""Decorate a function to be registered as a websocket route
|
||||
:param uri: path of the URL
|
||||
:param host:
|
||||
:return: decorated function
|
||||
"""
|
||||
self.websocket_enabled = True
|
||||
|
||||
# Fix case where the user did not prefix the URL with a /
|
||||
# and will probably get confused as to why it's not working
|
||||
if not uri.startswith('/'):
|
||||
uri = '/' + uri
|
||||
|
||||
def response(handler):
|
||||
async def websocket_handler(request, *args, **kwargs):
|
||||
request.app = self
|
||||
protocol = request.transport.get_protocol()
|
||||
ws = await protocol.websocket_handshake(request)
|
||||
try:
|
||||
# invoke the application handler
|
||||
await handler(request, ws, *args, **kwargs)
|
||||
except ConnectionClosed:
|
||||
pass
|
||||
await ws.close()
|
||||
|
||||
self.router.add(uri=uri, handler=websocket_handler,
|
||||
methods=frozenset({'GET'}), host=host)
|
||||
return handler
|
||||
|
||||
return response
|
||||
|
||||
def add_websocket_route(self, handler, uri, host=None):
|
||||
"""A helper method to register a function as a websocket route."""
|
||||
return self.websocket(uri, host=host)(handler)
|
||||
|
||||
def enable_websocket(self, enable=True):
|
||||
"""Enable or disable the support for websocket.
|
||||
|
||||
Websocket is enabled automatically if websocket routes are
|
||||
added to the application.
|
||||
"""
|
||||
self.websocket_enabled = enable
|
||||
|
||||
def remove_route(self, uri, clean_cache=True, host=None):
|
||||
self.router.remove(uri, clean_cache, host)
|
||||
|
||||
@@ -437,7 +483,7 @@ class Sanic:
|
||||
|
||||
def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None,
|
||||
after_start=None, before_stop=None, after_stop=None, ssl=None,
|
||||
sock=None, workers=1, loop=None, protocol=HttpProtocol,
|
||||
sock=None, workers=1, loop=None, protocol=None,
|
||||
backlog=100, stop_event=None, register_sys_signals=True):
|
||||
"""Run the HTTP Server and listen until keyboard interrupt or term
|
||||
signal. On termination, drain connections before closing.
|
||||
@@ -464,6 +510,9 @@ class Sanic:
|
||||
:param protocol: Subclass of asyncio protocol class
|
||||
:return: Nothing
|
||||
"""
|
||||
if protocol is None:
|
||||
protocol = (WebSocketProtocol if self.websocket_enabled
|
||||
else HttpProtocol)
|
||||
server_settings = self._helper(
|
||||
host=host, port=port, debug=debug, before_start=before_start,
|
||||
after_start=after_start, before_stop=before_stop,
|
||||
@@ -491,13 +540,16 @@ class Sanic:
|
||||
async def create_server(self, host="127.0.0.1", port=8000, debug=False,
|
||||
before_start=None, after_start=None,
|
||||
before_stop=None, after_stop=None, ssl=None,
|
||||
sock=None, loop=None, protocol=HttpProtocol,
|
||||
sock=None, loop=None, protocol=None,
|
||||
backlog=100, stop_event=None):
|
||||
"""Asynchronous version of `run`.
|
||||
|
||||
NOTE: This does not support multiprocessing and is not the preferred
|
||||
way to run a Sanic application.
|
||||
"""
|
||||
if protocol is None:
|
||||
protocol = (WebSocketProtocol if self.websocket_enabled
|
||||
else HttpProtocol)
|
||||
server_settings = self._helper(
|
||||
host=host, port=port, debug=debug, before_start=before_start,
|
||||
after_start=after_start, before_stop=before_stop,
|
||||
|
||||
@@ -23,6 +23,7 @@ class Blueprint:
|
||||
self.host = host
|
||||
|
||||
self.routes = []
|
||||
self.websocket_routes = []
|
||||
self.exceptions = []
|
||||
self.listeners = defaultdict(list)
|
||||
self.middlewares = []
|
||||
@@ -46,6 +47,17 @@ class Blueprint:
|
||||
host=future.host or self.host
|
||||
)(future.handler)
|
||||
|
||||
for future in self.websocket_routes:
|
||||
# attach the blueprint name to the handler so that it can be
|
||||
# prefixed properly in the router
|
||||
future.handler.__blueprintname__ = self.name
|
||||
# Prepend the blueprint URI prefix if available
|
||||
uri = url_prefix + future.uri if url_prefix else future.uri
|
||||
app.websocket(
|
||||
uri=uri,
|
||||
host=future.host or self.host
|
||||
)(future.handler)
|
||||
|
||||
# Middleware
|
||||
for future in self.middlewares:
|
||||
if future.args or future.kwargs:
|
||||
@@ -106,6 +118,28 @@ class Blueprint:
|
||||
self.route(uri=uri, methods=methods, host=host)(handler)
|
||||
return handler
|
||||
|
||||
def websocket(self, uri, host=None):
|
||||
"""Create a blueprint websocket route from a decorated function.
|
||||
|
||||
:param uri: endpoint at which the route will be accessible.
|
||||
"""
|
||||
def decorator(handler):
|
||||
route = FutureRoute(handler, uri, [], host)
|
||||
self.websocket_routes.append(route)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def add_websocket_route(self, handler, uri, host=None):
|
||||
"""Create a blueprint websocket route from a function.
|
||||
|
||||
:param handler: function for handling uri requests. Accepts function,
|
||||
or class instance with a view_class method.
|
||||
:param uri: endpoint at which the route will be accessible.
|
||||
:return: function or class instance
|
||||
"""
|
||||
self.websocket(uri=uri, host=host)(handler)
|
||||
return handler
|
||||
|
||||
def listener(self, event):
|
||||
"""Create a listener from a decorated function.
|
||||
|
||||
|
||||
67
sanic/websocket.py
Normal file
67
sanic/websocket.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.server import HttpProtocol
|
||||
from httptools import HttpParserUpgrade
|
||||
from websockets import handshake, WebSocketCommonProtocol, InvalidHandshake
|
||||
from websockets import ConnectionClosed # noqa
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
|
||||
def connection_timeout(self):
|
||||
# timeouts make no sense for websocket routes
|
||||
if self.websocket is None:
|
||||
super().connection_timeout()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self.websocket is not None:
|
||||
self.websocket.connection_lost(exc)
|
||||
super().connection_lost(exc)
|
||||
|
||||
def data_received(self, data):
|
||||
if self.websocket is not None:
|
||||
# pass the data to the websocket protocol
|
||||
self.websocket.data_received(data)
|
||||
else:
|
||||
try:
|
||||
super().data_received(data)
|
||||
except HttpParserUpgrade:
|
||||
# this is okay, it just indicates we've got an upgrade request
|
||||
pass
|
||||
|
||||
def write_response(self, response):
|
||||
if self.websocket is not None:
|
||||
# websocket requests do not write a response
|
||||
self.transport.close()
|
||||
else:
|
||||
super().write_response(response)
|
||||
|
||||
async def websocket_handshake(self, request):
|
||||
# let the websockets package do the handshake with the client
|
||||
headers = []
|
||||
|
||||
def get_header(k):
|
||||
return request.headers.get(k, '')
|
||||
|
||||
def set_header(k, v):
|
||||
headers.append((k, v))
|
||||
|
||||
try:
|
||||
key = handshake.check_request(get_header)
|
||||
handshake.build_response(set_header, key)
|
||||
except InvalidHandshake:
|
||||
raise InvalidUsage('Invalid websocket request')
|
||||
|
||||
# write the 101 response back to the client
|
||||
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
|
||||
for k, v in headers:
|
||||
rv += k.encode('utf-8') + b': ' + v.encode('utf-8') + b'\r\n'
|
||||
rv += b'\r\n'
|
||||
request.transport.write(rv)
|
||||
|
||||
# hook up the websocket protocol
|
||||
self.websocket = WebSocketCommonProtocol()
|
||||
self.websocket.connection_made(request.transport)
|
||||
return self.websocket
|
||||
Reference in New Issue
Block a user