addressed feedback

This commit is contained in:
Miguel Grinberg 2017-02-20 14:32:14 -08:00
parent 6e903ee7d5
commit 1d6e11ca10
4 changed files with 35 additions and 21 deletions

View File

@ -9,7 +9,7 @@ async def index(request):
return await file('websocket.html') return await file('websocket.html')
@app.ws('/feed') @app.websocket('/feed')
async def feed(request, ws): async def feed(request, ws):
while True: while True:
data = 'hello!' data = 'hello!'

View File

@ -19,7 +19,7 @@ from sanic.server import serve, serve_multiple, HttpProtocol
from sanic.static import register as static_register from sanic.static import register as static_register
from sanic.testing import TestClient from sanic.testing import TestClient
from sanic.views import CompositionView from sanic.views import CompositionView
from sanic.ws import WebSocketProtocol, ConnectionClosed from sanic.websocket import WebSocketProtocol, ConnectionClosed
class Sanic: class Sanic:
@ -52,7 +52,7 @@ class Sanic:
self.sock = None self.sock = None
self.listeners = defaultdict(list) self.listeners = defaultdict(list)
self.is_running = False self.is_running = False
self.needs_websocket = False self.websocket_enabled = False
# Register alternative method names # Register alternative method names
self.go_fast = self.run self.go_fast = self.run
@ -171,13 +171,13 @@ class Sanic:
return handler return handler
# Decorator # Decorator
def ws(self, uri, host=None): def websocket(self, uri, host=None):
"""Decorate a function to be registered as a websocket route """Decorate a function to be registered as a websocket route
:param uri: path of the URL :param uri: path of the URL
:param host: :param host:
:return: decorated function :return: decorated function
""" """
self.needs_websocket = True self.websocket_enabled = True
# Fix case where the user did not prefix the URL with a / # Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working # and will probably get confused as to why it's not working
@ -201,8 +201,17 @@ class Sanic:
return response return response
def add_ws_route(self, handler, uri, host=None): def add_websocket_route(self, handler, uri, host=None):
return self.ws(uri, host=host)(handler) """A helper method to register a function as a websocket route."""
return self.websocket(uri, host=host)(handler)
def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket.
Websocket is enabled automatically if websocket routes are
added to the application.
"""
self.websocket_enabled = enable
def remove_route(self, uri, clean_cache=True, host=None): def remove_route(self, uri, clean_cache=True, host=None):
self.router.remove(uri, clean_cache, host) self.router.remove(uri, clean_cache, host)
@ -501,8 +510,8 @@ class Sanic:
:return: Nothing :return: Nothing
""" """
if protocol is None: if protocol is None:
protocol = WebSocketProtocol if self.needs_websocket \ protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol else HttpProtocol)
server_settings = self._helper( server_settings = self._helper(
host=host, port=port, debug=debug, before_start=before_start, host=host, port=port, debug=debug, before_start=before_start,
after_start=after_start, before_stop=before_stop, after_start=after_start, before_stop=before_stop,
@ -538,8 +547,8 @@ class Sanic:
way to run a Sanic application. way to run a Sanic application.
""" """
if protocol is None: if protocol is None:
protocol = WebSocketProtocol if self.needs_websocket \ protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol else HttpProtocol)
server_settings = self._helper( server_settings = self._helper(
host=host, port=port, debug=debug, before_start=before_start, host=host, port=port, debug=debug, before_start=before_start,
after_start=after_start, before_stop=before_stop, after_start=after_start, before_stop=before_stop,

View File

@ -1,18 +1,19 @@
from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
from httptools import HttpParserUpgrade from httptools import HttpParserUpgrade
from websockets import handshake, WebSocketCommonProtocol from websockets import handshake, WebSocketCommonProtocol, InvalidHandshake
from websockets import ConnectionClosed # noqa from websockets import ConnectionClosed # noqa
class WebSocketProtocol(HttpProtocol): class WebSocketProtocol(HttpProtocol):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.ws = None self.websocket = None
def data_received(self, data): def data_received(self, data):
if self.ws is not None: if self.websocket is not None:
# pass the data to the websocket protocol # pass the data to the websocket protocol
self.ws.data_received(data) self.websocket.data_received(data)
else: else:
try: try:
super().data_received(data) super().data_received(data)
@ -21,7 +22,7 @@ class WebSocketProtocol(HttpProtocol):
pass pass
def write_response(self, response): def write_response(self, response):
if self.ws is not None: if self.websocket is not None:
# websocket requests do not write a response # websocket requests do not write a response
self.transport.close() self.transport.close()
else: else:
@ -37,8 +38,11 @@ class WebSocketProtocol(HttpProtocol):
def set_header(k, v): def set_header(k, v):
headers.append((k, v)) headers.append((k, v))
key = handshake.check_request(get_header) try:
handshake.build_response(set_header, key) key = handshake.check_request(get_header)
handshake.build_response(set_header, key)
except InvalidHandshake:
raise InvalidUsage('Invalid websocket request')
# write the 101 response back to the client # write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n' rv = b'HTTP/1.1 101 Switching Protocols\r\n'
@ -48,6 +52,6 @@ class WebSocketProtocol(HttpProtocol):
request.transport.write(rv) request.transport.write(rv)
# hook up the websocket protocol # hook up the websocket protocol
self.ws = WebSocketCommonProtocol() self.websocket = WebSocketCommonProtocol()
self.ws.connection_made(request.transport) self.websocket.connection_made(request.transport)
return self.ws return self.websocket

View File

@ -12,6 +12,7 @@ python =
deps = deps =
aiofiles aiofiles
aiohttp aiohttp
websockets
pytest pytest
beautifulsoup4 beautifulsoup4
coverage coverage