addressed feedback

This commit is contained in:
Miguel Grinberg 2017-02-20 14:32:14 -08:00
parent 6e903ee7d5
commit 1d6e11ca10
4 changed files with 35 additions and 21 deletions

View File

@ -9,7 +9,7 @@ async def index(request):
return await file('websocket.html')
@app.ws('/feed')
@app.websocket('/feed')
async def feed(request, ws):
while True:
data = 'hello!'

View File

@ -19,7 +19,7 @@ from sanic.server import serve, serve_multiple, HttpProtocol
from sanic.static import register as static_register
from sanic.testing import TestClient
from sanic.views import CompositionView
from sanic.ws import WebSocketProtocol, ConnectionClosed
from sanic.websocket import WebSocketProtocol, ConnectionClosed
class Sanic:
@ -52,7 +52,7 @@ class Sanic:
self.sock = None
self.listeners = defaultdict(list)
self.is_running = False
self.needs_websocket = False
self.websocket_enabled = False
# Register alternative method names
self.go_fast = self.run
@ -171,13 +171,13 @@ class Sanic:
return handler
# Decorator
def ws(self, uri, host=None):
def websocket(self, uri, host=None):
"""Decorate a function to be registered as a websocket route
:param uri: path of the URL
:param host:
:return: decorated function
"""
self.needs_websocket = True
self.websocket_enabled = True
# Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working
@ -201,8 +201,17 @@ class Sanic:
return response
def add_ws_route(self, handler, uri, host=None):
return self.ws(uri, host=host)(handler)
def add_websocket_route(self, handler, uri, host=None):
"""A helper method to register a function as a websocket route."""
return self.websocket(uri, host=host)(handler)
def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket.
Websocket is enabled automatically if websocket routes are
added to the application.
"""
self.websocket_enabled = enable
def remove_route(self, uri, clean_cache=True, host=None):
self.router.remove(uri, clean_cache, host)
@ -501,8 +510,8 @@ class Sanic:
:return: Nothing
"""
if protocol is None:
protocol = WebSocketProtocol if self.needs_websocket \
else HttpProtocol
protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol)
server_settings = self._helper(
host=host, port=port, debug=debug, before_start=before_start,
after_start=after_start, before_stop=before_stop,
@ -538,8 +547,8 @@ class Sanic:
way to run a Sanic application.
"""
if protocol is None:
protocol = WebSocketProtocol if self.needs_websocket \
else HttpProtocol
protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol)
server_settings = self._helper(
host=host, port=port, debug=debug, before_start=before_start,
after_start=after_start, before_stop=before_stop,

View File

@ -1,18 +1,19 @@
from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol
from httptools import HttpParserUpgrade
from websockets import handshake, WebSocketCommonProtocol
from websockets import handshake, WebSocketCommonProtocol, InvalidHandshake
from websockets import ConnectionClosed # noqa
class WebSocketProtocol(HttpProtocol):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ws = None
self.websocket = None
def data_received(self, data):
if self.ws is not None:
if self.websocket is not None:
# pass the data to the websocket protocol
self.ws.data_received(data)
self.websocket.data_received(data)
else:
try:
super().data_received(data)
@ -21,7 +22,7 @@ class WebSocketProtocol(HttpProtocol):
pass
def write_response(self, response):
if self.ws is not None:
if self.websocket is not None:
# websocket requests do not write a response
self.transport.close()
else:
@ -37,8 +38,11 @@ class WebSocketProtocol(HttpProtocol):
def set_header(k, v):
headers.append((k, v))
key = handshake.check_request(get_header)
handshake.build_response(set_header, key)
try:
key = handshake.check_request(get_header)
handshake.build_response(set_header, key)
except InvalidHandshake:
raise InvalidUsage('Invalid websocket request')
# write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
@ -48,6 +52,6 @@ class WebSocketProtocol(HttpProtocol):
request.transport.write(rv)
# hook up the websocket protocol
self.ws = WebSocketCommonProtocol()
self.ws.connection_made(request.transport)
return self.ws
self.websocket = WebSocketCommonProtocol()
self.websocket.connection_made(request.transport)
return self.websocket

View File

@ -12,6 +12,7 @@ python =
deps =
aiofiles
aiohttp
websockets
pytest
beautifulsoup4
coverage