addressed feedback
This commit is contained in:
		| @@ -9,7 +9,7 @@ async def index(request): | |||||||
|     return await file('websocket.html') |     return await file('websocket.html') | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.ws('/feed') | @app.websocket('/feed') | ||||||
| async def feed(request, ws): | async def feed(request, ws): | ||||||
|     while True: |     while True: | ||||||
|         data = 'hello!' |         data = 'hello!' | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -19,7 +19,7 @@ from sanic.server import serve, serve_multiple, HttpProtocol | |||||||
| from sanic.static import register as static_register | from sanic.static import register as static_register | ||||||
| from sanic.testing import TestClient | from sanic.testing import TestClient | ||||||
| from sanic.views import CompositionView | from sanic.views import CompositionView | ||||||
| from sanic.ws import WebSocketProtocol, ConnectionClosed | from sanic.websocket import WebSocketProtocol, ConnectionClosed | ||||||
|  |  | ||||||
|  |  | ||||||
| class Sanic: | class Sanic: | ||||||
| @@ -52,7 +52,7 @@ class Sanic: | |||||||
|         self.sock = None |         self.sock = None | ||||||
|         self.listeners = defaultdict(list) |         self.listeners = defaultdict(list) | ||||||
|         self.is_running = False |         self.is_running = False | ||||||
|         self.needs_websocket = False |         self.websocket_enabled = False | ||||||
|  |  | ||||||
|         # Register alternative method names |         # Register alternative method names | ||||||
|         self.go_fast = self.run |         self.go_fast = self.run | ||||||
| @@ -171,13 +171,13 @@ class Sanic: | |||||||
|         return handler |         return handler | ||||||
|  |  | ||||||
|     # Decorator |     # Decorator | ||||||
|     def ws(self, uri, host=None): |     def websocket(self, uri, host=None): | ||||||
|         """Decorate a function to be registered as a websocket route |         """Decorate a function to be registered as a websocket route | ||||||
|         :param uri: path of the URL |         :param uri: path of the URL | ||||||
|         :param host: |         :param host: | ||||||
|         :return: decorated function |         :return: decorated function | ||||||
|         """ |         """ | ||||||
|         self.needs_websocket = True |         self.websocket_enabled = True | ||||||
|  |  | ||||||
|         # Fix case where the user did not prefix the URL with a / |         # Fix case where the user did not prefix the URL with a / | ||||||
|         # and will probably get confused as to why it's not working |         # and will probably get confused as to why it's not working | ||||||
| @@ -201,8 +201,17 @@ class Sanic: | |||||||
|  |  | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|     def add_ws_route(self, handler, uri, host=None): |     def add_websocket_route(self, handler, uri, host=None): | ||||||
|         return self.ws(uri, host=host)(handler) |         """A helper method to register a function as a websocket route.""" | ||||||
|  |         return self.websocket(uri, host=host)(handler) | ||||||
|  |  | ||||||
|  |     def enable_websocket(self, enable=True): | ||||||
|  |         """Enable or disable the support for websocket. | ||||||
|  |  | ||||||
|  |         Websocket is enabled automatically if websocket routes are | ||||||
|  |         added to the application. | ||||||
|  |         """ | ||||||
|  |         self.websocket_enabled = enable | ||||||
|  |  | ||||||
|     def remove_route(self, uri, clean_cache=True, host=None): |     def remove_route(self, uri, clean_cache=True, host=None): | ||||||
|         self.router.remove(uri, clean_cache, host) |         self.router.remove(uri, clean_cache, host) | ||||||
| @@ -501,8 +510,8 @@ class Sanic: | |||||||
|         :return: Nothing |         :return: Nothing | ||||||
|         """ |         """ | ||||||
|         if protocol is None: |         if protocol is None: | ||||||
|             protocol = WebSocketProtocol if self.needs_websocket \ |             protocol = (WebSocketProtocol if self.websocket_enabled | ||||||
|                 else HttpProtocol |                         else HttpProtocol) | ||||||
|         server_settings = self._helper( |         server_settings = self._helper( | ||||||
|             host=host, port=port, debug=debug, before_start=before_start, |             host=host, port=port, debug=debug, before_start=before_start, | ||||||
|             after_start=after_start, before_stop=before_stop, |             after_start=after_start, before_stop=before_stop, | ||||||
| @@ -538,8 +547,8 @@ class Sanic: | |||||||
|               way to run a Sanic application. |               way to run a Sanic application. | ||||||
|         """ |         """ | ||||||
|         if protocol is None: |         if protocol is None: | ||||||
|             protocol = WebSocketProtocol if self.needs_websocket \ |             protocol = (WebSocketProtocol if self.websocket_enabled | ||||||
|                 else HttpProtocol |                         else HttpProtocol) | ||||||
|         server_settings = self._helper( |         server_settings = self._helper( | ||||||
|             host=host, port=port, debug=debug, before_start=before_start, |             host=host, port=port, debug=debug, before_start=before_start, | ||||||
|             after_start=after_start, before_stop=before_stop, |             after_start=after_start, before_stop=before_stop, | ||||||
|   | |||||||
| @@ -1,18 +1,19 @@ | |||||||
|  | from sanic.exceptions import InvalidUsage | ||||||
| from sanic.server import HttpProtocol | from sanic.server import HttpProtocol | ||||||
| from httptools import HttpParserUpgrade | from httptools import HttpParserUpgrade | ||||||
| from websockets import handshake, WebSocketCommonProtocol | from websockets import handshake, WebSocketCommonProtocol, InvalidHandshake | ||||||
| from websockets import ConnectionClosed  # noqa | from websockets import ConnectionClosed  # noqa | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class WebSocketProtocol(HttpProtocol): | class WebSocketProtocol(HttpProtocol): | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.ws = None |         self.websocket = None | ||||||
| 
 | 
 | ||||||
|     def data_received(self, data): |     def data_received(self, data): | ||||||
|         if self.ws is not None: |         if self.websocket is not None: | ||||||
|             # pass the data to the websocket protocol |             # pass the data to the websocket protocol | ||||||
|             self.ws.data_received(data) |             self.websocket.data_received(data) | ||||||
|         else: |         else: | ||||||
|             try: |             try: | ||||||
|                 super().data_received(data) |                 super().data_received(data) | ||||||
| @@ -21,7 +22,7 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|                 pass |                 pass | ||||||
| 
 | 
 | ||||||
|     def write_response(self, response): |     def write_response(self, response): | ||||||
|         if self.ws is not None: |         if self.websocket is not None: | ||||||
|             # websocket requests do not write a response |             # websocket requests do not write a response | ||||||
|             self.transport.close() |             self.transport.close() | ||||||
|         else: |         else: | ||||||
| @@ -37,8 +38,11 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|         def set_header(k, v): |         def set_header(k, v): | ||||||
|             headers.append((k, v)) |             headers.append((k, v)) | ||||||
| 
 | 
 | ||||||
|  |         try: | ||||||
|             key = handshake.check_request(get_header) |             key = handshake.check_request(get_header) | ||||||
|             handshake.build_response(set_header, key) |             handshake.build_response(set_header, key) | ||||||
|  |         except InvalidHandshake: | ||||||
|  |             raise InvalidUsage('Invalid websocket request') | ||||||
| 
 | 
 | ||||||
|         # write the 101 response back to the client |         # write the 101 response back to the client | ||||||
|         rv = b'HTTP/1.1 101 Switching Protocols\r\n' |         rv = b'HTTP/1.1 101 Switching Protocols\r\n' | ||||||
| @@ -48,6 +52,6 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|         request.transport.write(rv) |         request.transport.write(rv) | ||||||
| 
 | 
 | ||||||
|         # hook up the websocket protocol |         # hook up the websocket protocol | ||||||
|         self.ws = WebSocketCommonProtocol() |         self.websocket = WebSocketCommonProtocol() | ||||||
|         self.ws.connection_made(request.transport) |         self.websocket.connection_made(request.transport) | ||||||
|         return self.ws |         return self.websocket | ||||||
		Reference in New Issue
	
	Block a user
	 Miguel Grinberg
					Miguel Grinberg