diff --git a/examples/websocket.py b/examples/websocket.py index ecc14f8d..16cc4015 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -9,7 +9,7 @@ async def index(request): return await file('websocket.html') -@app.ws('/feed') +@app.websocket('/feed') async def feed(request, ws): while True: data = 'hello!' diff --git a/sanic/app.py b/sanic/app.py index b4f7fcb7..633fef17 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -19,7 +19,7 @@ from sanic.server import serve, serve_multiple, HttpProtocol from sanic.static import register as static_register from sanic.testing import TestClient from sanic.views import CompositionView -from sanic.ws import WebSocketProtocol, ConnectionClosed +from sanic.websocket import WebSocketProtocol, ConnectionClosed class Sanic: @@ -52,7 +52,7 @@ class Sanic: self.sock = None self.listeners = defaultdict(list) self.is_running = False - self.needs_websocket = False + self.websocket_enabled = False # Register alternative method names self.go_fast = self.run @@ -171,13 +171,13 @@ class Sanic: return handler # Decorator - def ws(self, uri, host=None): + def websocket(self, uri, host=None): """Decorate a function to be registered as a websocket route :param uri: path of the URL :param host: :return: decorated function """ - self.needs_websocket = True + self.websocket_enabled = True # Fix case where the user did not prefix the URL with a / # and will probably get confused as to why it's not working @@ -201,8 +201,17 @@ class Sanic: return response - def add_ws_route(self, handler, uri, host=None): - return self.ws(uri, host=host)(handler) + def add_websocket_route(self, handler, uri, host=None): + """A helper method to register a function as a websocket route.""" + return self.websocket(uri, host=host)(handler) + + def enable_websocket(self, enable=True): + """Enable or disable the support for websocket. + + Websocket is enabled automatically if websocket routes are + added to the application. + """ + self.websocket_enabled = enable def remove_route(self, uri, clean_cache=True, host=None): self.router.remove(uri, clean_cache, host) @@ -501,8 +510,8 @@ class Sanic: :return: Nothing """ if protocol is None: - protocol = WebSocketProtocol if self.needs_websocket \ - else HttpProtocol + protocol = (WebSocketProtocol if self.websocket_enabled + else HttpProtocol) server_settings = self._helper( host=host, port=port, debug=debug, before_start=before_start, after_start=after_start, before_stop=before_stop, @@ -538,8 +547,8 @@ class Sanic: way to run a Sanic application. """ if protocol is None: - protocol = WebSocketProtocol if self.needs_websocket \ - else HttpProtocol + protocol = (WebSocketProtocol if self.websocket_enabled + else HttpProtocol) server_settings = self._helper( host=host, port=port, debug=debug, before_start=before_start, after_start=after_start, before_stop=before_stop, diff --git a/sanic/ws.py b/sanic/websocket.py similarity index 68% rename from sanic/ws.py rename to sanic/websocket.py index 8d0ddc70..3fc0a092 100644 --- a/sanic/ws.py +++ b/sanic/websocket.py @@ -1,18 +1,19 @@ +from sanic.exceptions import InvalidUsage from sanic.server import HttpProtocol from httptools import HttpParserUpgrade -from websockets import handshake, WebSocketCommonProtocol +from websockets import handshake, WebSocketCommonProtocol, InvalidHandshake from websockets import ConnectionClosed # noqa class WebSocketProtocol(HttpProtocol): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.ws = None + self.websocket = None def data_received(self, data): - if self.ws is not None: + if self.websocket is not None: # pass the data to the websocket protocol - self.ws.data_received(data) + self.websocket.data_received(data) else: try: super().data_received(data) @@ -21,7 +22,7 @@ class WebSocketProtocol(HttpProtocol): pass def write_response(self, response): - if self.ws is not None: + if self.websocket is not None: # websocket requests do not write a response self.transport.close() else: @@ -37,8 +38,11 @@ class WebSocketProtocol(HttpProtocol): def set_header(k, v): headers.append((k, v)) - key = handshake.check_request(get_header) - handshake.build_response(set_header, key) + try: + key = handshake.check_request(get_header) + handshake.build_response(set_header, key) + except InvalidHandshake: + raise InvalidUsage('Invalid websocket request') # write the 101 response back to the client rv = b'HTTP/1.1 101 Switching Protocols\r\n' @@ -48,6 +52,6 @@ class WebSocketProtocol(HttpProtocol): request.transport.write(rv) # hook up the websocket protocol - self.ws = WebSocketCommonProtocol() - self.ws.connection_made(request.transport) - return self.ws + self.websocket = WebSocketCommonProtocol() + self.websocket.connection_made(request.transport) + return self.websocket diff --git a/tox.ini b/tox.ini index 05276a43..33e4298f 100644 --- a/tox.ini +++ b/tox.ini @@ -12,6 +12,7 @@ python = deps = aiofiles aiohttp + websockets pytest beautifulsoup4 coverage