Merge pull request #1475 from tomchristie/asgi-refactor-attempt
ASGI refactoring attempt
This commit is contained in:
		| @@ -5,3 +5,11 @@ omit = site-packages, sanic/utils.py, sanic/__main__.py | |||||||
|  |  | ||||||
| [html] | [html] | ||||||
| directory = coverage | directory = coverage | ||||||
|  |  | ||||||
|  | [report] | ||||||
|  | exclude_lines = | ||||||
|  |     no cov | ||||||
|  |     no qa | ||||||
|  |     noqa | ||||||
|  |     NOQA | ||||||
|  |     pragma: no cover | ||||||
|   | |||||||
| @@ -1,7 +1,12 @@ | |||||||
| # Deploying | # Deploying | ||||||
|  |  | ||||||
| Deploying Sanic is made simple by the inbuilt webserver. After defining an | Deploying Sanic is very simple using one of three options: the inbuilt webserver, | ||||||
| instance of `sanic.Sanic`, we can call the `run` method with the following | an [ASGI webserver](https://asgi.readthedocs.io/en/latest/implementations.html), or `gunicorn`. | ||||||
|  | It is also very common to place Sanic behind a reverse proxy, like `nginx`.  | ||||||
|  |  | ||||||
|  | ## Running via Sanic webserver | ||||||
|  |  | ||||||
|  | After defining an instance of `sanic.Sanic`, we can call the `run` method with the following | ||||||
| keyword arguments: | keyword arguments: | ||||||
|  |  | ||||||
| - `host` *(default `"127.0.0.1"`)*: Address to host the server on. | - `host` *(default `"127.0.0.1"`)*: Address to host the server on. | ||||||
| @@ -17,7 +22,13 @@ keyword arguments: | |||||||
|   [asyncio.protocol](https://docs.python.org/3/library/asyncio-protocol.html#protocol-classes). |   [asyncio.protocol](https://docs.python.org/3/library/asyncio-protocol.html#protocol-classes). | ||||||
| - `access_log` *(default `True`)*: Enables log on handling requests (significantly slows server). | - `access_log` *(default `True`)*: Enables log on handling requests (significantly slows server). | ||||||
|  |  | ||||||
| ## Workers | ```python | ||||||
|  | app.run(host='0.0.0.0', port=1337, access_log=False) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | In the above example, we decided to turn off the access log in order to increase performance. | ||||||
|  |  | ||||||
|  | ### Workers | ||||||
|  |  | ||||||
| By default, Sanic listens in the main process using only one CPU core. To crank | By default, Sanic listens in the main process using only one CPU core. To crank | ||||||
| up the juice, just specify the number of workers in the `run` arguments. | up the juice, just specify the number of workers in the `run` arguments. | ||||||
| @@ -29,9 +40,9 @@ app.run(host='0.0.0.0', port=1337, workers=4) | |||||||
| Sanic will automatically spin up multiple processes and route traffic between | Sanic will automatically spin up multiple processes and route traffic between | ||||||
| them. We recommend as many workers as you have available cores. | them. We recommend as many workers as you have available cores. | ||||||
|  |  | ||||||
| ## Running via command | ### Running via command | ||||||
|  |  | ||||||
| If you like using command line arguments, you can launch a Sanic server by | If you like using command line arguments, you can launch a Sanic webserver by | ||||||
| executing the module. For example, if you initialized Sanic as `app` in a file | executing the module. For example, if you initialized Sanic as `app` in a file | ||||||
| named `server.py`, you could run the server like so: | named `server.py`, you could run the server like so: | ||||||
|  |  | ||||||
| @@ -46,6 +57,33 @@ if __name__ == '__main__': | |||||||
|     app.run(host='0.0.0.0', port=1337, workers=4) |     app.run(host='0.0.0.0', port=1337, workers=4) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  | ## Running via ASGI | ||||||
|  |  | ||||||
|  | Sanic is also ASGI-compliant. This means you can use your preferred ASGI webserver | ||||||
|  | to run Sanic. The three main implementations of ASGI are | ||||||
|  | [Daphne](http://github.com/django/daphne), [Uvicorn](https://www.uvicorn.org/), | ||||||
|  | and [Hypercorn](https://pgjones.gitlab.io/hypercorn/index.html). | ||||||
|  |  | ||||||
|  | Follow their documentation for the proper way to run them, but it should look | ||||||
|  | something like: | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | daphne myapp:app | ||||||
|  | uvicorn myapp:app | ||||||
|  | hypercorn myapp:app | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | A couple things to note when using ASGI: | ||||||
|  |  | ||||||
|  | 1. When using the Sanic webserver, websockets will run using the [`websockets`](https://websockets.readthedocs.io/) package. In ASGI mode, there is no need for this package since websockets are managed in the ASGI server. | ||||||
|  | 1. The ASGI [lifespan protocol](https://asgi.readthedocs.io/en/latest/specs/lifespan.html) supports | ||||||
|  | only two server events: startup and shutdown. Sanic has four: before startup, after startup,  | ||||||
|  | before shutdown, and after shutdown. Therefore, in ASGI mode, the startup and shutdown events will  | ||||||
|  | run consecutively and not actually around the server process beginning and ending (since that  | ||||||
|  | is now controlled by the ASGI server). Therefore, it is best to use `after_server_start` and  | ||||||
|  | `before_server_stop`. | ||||||
|  | 1. ASGI mode is still in "beta" as of Sanic v19.6. | ||||||
|  |  | ||||||
| ## Running via Gunicorn | ## Running via Gunicorn | ||||||
|  |  | ||||||
| [Gunicorn](http://gunicorn.org/) ‘Green Unicorn’ is a WSGI HTTP Server for UNIX. | [Gunicorn](http://gunicorn.org/) ‘Green Unicorn’ is a WSGI HTTP Server for UNIX. | ||||||
| @@ -64,7 +102,9 @@ of the memory leak. | |||||||
|  |  | ||||||
| See the [Gunicorn Docs](http://docs.gunicorn.org/en/latest/settings.html#max-requests) for more information. | See the [Gunicorn Docs](http://docs.gunicorn.org/en/latest/settings.html#max-requests) for more information. | ||||||
|  |  | ||||||
| ## Running behind a reverse proxy | ## Other deployment considerations | ||||||
|  |  | ||||||
|  | ### Running behind a reverse proxy | ||||||
|  |  | ||||||
| Sanic can be used with a reverse proxy (e.g. nginx). There's a simple example of nginx configuration: | Sanic can be used with a reverse proxy (e.g. nginx). There's a simple example of nginx configuration: | ||||||
|  |  | ||||||
| @@ -84,7 +124,7 @@ server { | |||||||
|  |  | ||||||
| If you want to get real client ip, you should configure `X-Real-IP` and `X-Forwarded-For` HTTP headers and set `app.config.PROXIES_COUNT` to `1`; see the configuration page for more information. | If you want to get real client ip, you should configure `X-Real-IP` and `X-Forwarded-For` HTTP headers and set `app.config.PROXIES_COUNT` to `1`; see the configuration page for more information. | ||||||
|  |  | ||||||
| ## Disable debug logging | ### Disable debug logging for performance | ||||||
|  |  | ||||||
| To improve the performance add `debug=False` and `access_log=False` in the `run` arguments. | To improve the performance add `debug=False` and `access_log=False` in the `run` arguments. | ||||||
|  |  | ||||||
| @@ -104,9 +144,10 @@ Or you can rewrite app config directly | |||||||
| app.config.ACCESS_LOG = False | app.config.ACCESS_LOG = False | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ## Asynchronous support | ### Asynchronous support and sharing the loop | ||||||
| This is suitable if you *need* to share the sanic process with other applications, in particular the `loop`. |  | ||||||
| However be advised that this method does not support using multiple processes, and is not the preferred way | This is suitable if you *need* to share the Sanic process with other applications, in particular the `loop`. | ||||||
|  | However, be advised that this method does not support using multiple processes, and is not the preferred way | ||||||
| to run the app in general. | to run the app in general. | ||||||
|  |  | ||||||
| Here is an incomplete example (please see `run_async.py` in examples for something more practical): | Here is an incomplete example (please see `run_async.py` in examples for something more practical): | ||||||
|   | |||||||
							
								
								
									
										88
									
								
								examples/run_asgi.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								examples/run_asgi.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | |||||||
|  | """ | ||||||
|  | 1. Create a simple Sanic app | ||||||
|  | 0. Run with an ASGI server: | ||||||
|  |     $ uvicorn run_asgi:app | ||||||
|  |     or | ||||||
|  |     $ hypercorn run_asgi:app | ||||||
|  | """ | ||||||
|  |  | ||||||
|  | from pathlib import Path | ||||||
|  | from sanic import Sanic, response | ||||||
|  |  | ||||||
|  |  | ||||||
|  | app = Sanic(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.route("/text") | ||||||
|  | def handler_text(request): | ||||||
|  |     return response.text("Hello") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.route("/json") | ||||||
|  | def handler_json(request): | ||||||
|  |     return response.json({"foo": "bar"}) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.websocket("/ws") | ||||||
|  | async def handler_ws(request, ws): | ||||||
|  |     name = "<someone>" | ||||||
|  |     while True: | ||||||
|  |         data = f"Hello {name}" | ||||||
|  |         await ws.send(data) | ||||||
|  |         name = await ws.recv() | ||||||
|  |  | ||||||
|  |         if not name: | ||||||
|  |             break | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.route("/file") | ||||||
|  | async def handler_file(request): | ||||||
|  |     return await response.file(Path("../") / "setup.py") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.route("/file_stream") | ||||||
|  | async def handler_file_stream(request): | ||||||
|  |     return await response.file_stream( | ||||||
|  |         Path("../") / "setup.py", chunk_size=1024 | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.route("/stream", stream=True) | ||||||
|  | async def handler_stream(request): | ||||||
|  |     while True: | ||||||
|  |         body = await request.stream.read() | ||||||
|  |         if body is None: | ||||||
|  |             break | ||||||
|  |         body = body.decode("utf-8").replace("1", "A") | ||||||
|  |         # await response.write(body) | ||||||
|  |     return response.stream(body) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.listener("before_server_start") | ||||||
|  | async def listener_before_server_start(*args, **kwargs): | ||||||
|  |     print("before_server_start") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.listener("after_server_start") | ||||||
|  | async def listener_after_server_start(*args, **kwargs): | ||||||
|  |     print("after_server_start") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.listener("before_server_stop") | ||||||
|  | async def listener_before_server_stop(*args, **kwargs): | ||||||
|  |     print("before_server_stop") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.listener("after_server_stop") | ||||||
|  | async def listener_after_server_stop(*args, **kwargs): | ||||||
|  |     print("after_server_stop") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.middleware("request") | ||||||
|  | async def print_on_request(request): | ||||||
|  |     print("print_on_request") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.middleware("response") | ||||||
|  | async def print_on_response(request, response): | ||||||
|  |     print("print_on_response") | ||||||
							
								
								
									
										62
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -15,6 +15,7 @@ from typing import Any, Optional, Type, Union | |||||||
| from urllib.parse import urlencode, urlunparse | from urllib.parse import urlencode, urlunparse | ||||||
|  |  | ||||||
| from sanic import reloader_helpers | from sanic import reloader_helpers | ||||||
|  | from sanic.asgi import ASGIApp | ||||||
| from sanic.blueprint_group import BlueprintGroup | from sanic.blueprint_group import BlueprintGroup | ||||||
| from sanic.config import BASE_LOGO, Config | from sanic.config import BASE_LOGO, Config | ||||||
| from sanic.constants import HTTP_METHODS | from sanic.constants import HTTP_METHODS | ||||||
| @@ -25,7 +26,7 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse | |||||||
| from sanic.router import Router | from sanic.router import Router | ||||||
| from sanic.server import HttpProtocol, Signal, serve, serve_multiple | from sanic.server import HttpProtocol, Signal, serve, serve_multiple | ||||||
| from sanic.static import register as static_register | from sanic.static import register as static_register | ||||||
| from sanic.testing import SanicTestClient | from sanic.testing import SanicASGITestClient, SanicTestClient | ||||||
| from sanic.views import CompositionView | from sanic.views import CompositionView | ||||||
| from sanic.websocket import ConnectionClosed, WebSocketProtocol | from sanic.websocket import ConnectionClosed, WebSocketProtocol | ||||||
|  |  | ||||||
| @@ -53,6 +54,7 @@ class Sanic: | |||||||
|             logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) |             logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) | ||||||
|  |  | ||||||
|         self.name = name |         self.name = name | ||||||
|  |         self.asgi = False | ||||||
|         self.router = router or Router() |         self.router = router or Router() | ||||||
|         self.request_class = request_class |         self.request_class = request_class | ||||||
|         self.error_handler = error_handler or ErrorHandler() |         self.error_handler = error_handler or ErrorHandler() | ||||||
| @@ -80,7 +82,7 @@ class Sanic: | |||||||
|  |  | ||||||
|         Only supported when using the `app.run` method. |         Only supported when using the `app.run` method. | ||||||
|         """ |         """ | ||||||
|         if not self.is_running: |         if not self.is_running and self.asgi is False: | ||||||
|             raise SanicException( |             raise SanicException( | ||||||
|                 "Loop can only be retrieved after the app has started " |                 "Loop can only be retrieved after the app has started " | ||||||
|                 "running. Not supported with `create_server` function" |                 "running. Not supported with `create_server` function" | ||||||
| @@ -469,13 +471,23 @@ class Sanic: | |||||||
|                         getattr(handler, "__blueprintname__", "") |                         getattr(handler, "__blueprintname__", "") | ||||||
|                         + handler.__name__ |                         + handler.__name__ | ||||||
|                     ) |                     ) | ||||||
|                 try: |  | ||||||
|                     protocol = request.transport.get_protocol() |                     pass | ||||||
|                 except AttributeError: |  | ||||||
|                     # On Python3.5 the Transport classes in asyncio do not |                 if self.asgi: | ||||||
|                     # have a get_protocol() method as in uvloop |                     ws = request.transport.get_websocket_connection() | ||||||
|                     protocol = request.transport._protocol |                 else: | ||||||
|                 ws = await protocol.websocket_handshake(request, subprotocols) |                     try: | ||||||
|  |                         protocol = request.transport.get_protocol() | ||||||
|  |                     except AttributeError: | ||||||
|  |                         # On Python3.5 the Transport classes in asyncio do not | ||||||
|  |                         # have a get_protocol() method as in uvloop | ||||||
|  |                         protocol = request.transport._protocol | ||||||
|  |                     protocol.app = self | ||||||
|  |  | ||||||
|  |                     ws = await protocol.websocket_handshake( | ||||||
|  |                         request, subprotocols | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|                 # schedule the application handler |                 # schedule the application handler | ||||||
|                 # its future is kept in self.websocket_tasks in case it |                 # its future is kept in self.websocket_tasks in case it | ||||||
| @@ -983,8 +995,16 @@ class Sanic: | |||||||
|                 raise CancelledError() |                 raise CancelledError() | ||||||
|  |  | ||||||
|         # pass the response to the correct callback |         # pass the response to the correct callback | ||||||
|         if isinstance(response, StreamingHTTPResponse): |         if write_callback is None or isinstance( | ||||||
|             await stream_callback(response) |             response, StreamingHTTPResponse | ||||||
|  |         ): | ||||||
|  |             if stream_callback: | ||||||
|  |                 await stream_callback(response) | ||||||
|  |             else: | ||||||
|  |                 # Should only end here IF it is an ASGI websocket. | ||||||
|  |                 # TODO: | ||||||
|  |                 # - Add exception handling | ||||||
|  |                 pass | ||||||
|         else: |         else: | ||||||
|             write_callback(response) |             write_callback(response) | ||||||
|  |  | ||||||
| @@ -996,6 +1016,10 @@ class Sanic: | |||||||
|     def test_client(self): |     def test_client(self): | ||||||
|         return SanicTestClient(self) |         return SanicTestClient(self) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def asgi_client(self): | ||||||
|  |         return SanicASGITestClient(self) | ||||||
|  |  | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
|     # Execution |     # Execution | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
| @@ -1122,10 +1146,6 @@ class Sanic: | |||||||
|         """This kills the Sanic""" |         """This kills the Sanic""" | ||||||
|         get_event_loop().stop() |         get_event_loop().stop() | ||||||
|  |  | ||||||
|     def __call__(self): |  | ||||||
|         """gunicorn compatibility""" |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|     async def create_server( |     async def create_server( | ||||||
|         self, |         self, | ||||||
|         host: Optional[str] = None, |         host: Optional[str] = None, | ||||||
| @@ -1367,3 +1387,15 @@ class Sanic: | |||||||
|     def _build_endpoint_name(self, *parts): |     def _build_endpoint_name(self, *parts): | ||||||
|         parts = [self.name, *parts] |         parts = [self.name, *parts] | ||||||
|         return ".".join(parts) |         return ".".join(parts) | ||||||
|  |  | ||||||
|  |     # -------------------------------------------------------------------- # | ||||||
|  |     # ASGI | ||||||
|  |     # -------------------------------------------------------------------- # | ||||||
|  |  | ||||||
|  |     async def __call__(self, scope, receive, send): | ||||||
|  |         """To be ASGI compliant, our instance must be a callable that accepts | ||||||
|  |         three arguments: scope, receive, send. See the ASGI reference for more | ||||||
|  |         details: https://asgi.readthedocs.io/en/latest/""" | ||||||
|  |         self.asgi = True | ||||||
|  |         asgi_app = await ASGIApp.create(self, scope, receive, send) | ||||||
|  |         await asgi_app() | ||||||
|   | |||||||
							
								
								
									
										350
									
								
								sanic/asgi.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										350
									
								
								sanic/asgi.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,350 @@ | |||||||
|  | import asyncio | ||||||
|  | import warnings | ||||||
|  |  | ||||||
|  | from http.cookies import SimpleCookie | ||||||
|  | from inspect import isawaitable | ||||||
|  | from typing import Any, Awaitable, Callable, MutableMapping, Union | ||||||
|  | from urllib.parse import quote | ||||||
|  |  | ||||||
|  | from multidict import CIMultiDict | ||||||
|  |  | ||||||
|  | from sanic.exceptions import InvalidUsage, ServerError | ||||||
|  | from sanic.log import logger | ||||||
|  | from sanic.request import Request | ||||||
|  | from sanic.response import HTTPResponse, StreamingHTTPResponse | ||||||
|  | from sanic.server import StreamBuffer | ||||||
|  | from sanic.websocket import WebSocketConnection | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ASGIScope = MutableMapping[str, Any] | ||||||
|  | ASGIMessage = MutableMapping[str, Any] | ||||||
|  | ASGISend = Callable[[ASGIMessage], Awaitable[None]] | ||||||
|  | ASGIReceive = Callable[[], Awaitable[ASGIMessage]] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MockProtocol: | ||||||
|  |     def __init__(self, transport: "MockTransport", loop): | ||||||
|  |         self.transport = transport | ||||||
|  |         self._not_paused = asyncio.Event(loop=loop) | ||||||
|  |         self._not_paused.set() | ||||||
|  |         self._complete = asyncio.Event(loop=loop) | ||||||
|  |  | ||||||
|  |     def pause_writing(self) -> None: | ||||||
|  |         self._not_paused.clear() | ||||||
|  |  | ||||||
|  |     def resume_writing(self) -> None: | ||||||
|  |         self._not_paused.set() | ||||||
|  |  | ||||||
|  |     async def complete(self) -> None: | ||||||
|  |         self._not_paused.set() | ||||||
|  |         await self.transport.send( | ||||||
|  |             {"type": "http.response.body", "body": b"", "more_body": False} | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def is_complete(self) -> bool: | ||||||
|  |         return self._complete.is_set() | ||||||
|  |  | ||||||
|  |     async def push_data(self, data: bytes) -> None: | ||||||
|  |         if not self.is_complete: | ||||||
|  |             await self.transport.send( | ||||||
|  |                 {"type": "http.response.body", "body": data, "more_body": True} | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     async def drain(self) -> None: | ||||||
|  |         await self._not_paused.wait() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MockTransport: | ||||||
|  |     def __init__( | ||||||
|  |         self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||||
|  |     ) -> None: | ||||||
|  |         self.scope = scope | ||||||
|  |         self._receive = receive | ||||||
|  |         self._send = send | ||||||
|  |         self._protocol = None | ||||||
|  |         self.loop = None | ||||||
|  |  | ||||||
|  |     def get_protocol(self) -> MockProtocol: | ||||||
|  |         if not self._protocol: | ||||||
|  |             self._protocol = MockProtocol(self, self.loop) | ||||||
|  |         return self._protocol | ||||||
|  |  | ||||||
|  |     def get_extra_info(self, info: str) -> Union[str, bool]: | ||||||
|  |         if info == "peername": | ||||||
|  |             return self.scope.get("server") | ||||||
|  |         elif info == "sslcontext": | ||||||
|  |             return self.scope.get("scheme") in ["https", "wss"] | ||||||
|  |  | ||||||
|  |     def get_websocket_connection(self) -> WebSocketConnection: | ||||||
|  |         try: | ||||||
|  |             return self._websocket_connection | ||||||
|  |         except AttributeError: | ||||||
|  |             raise InvalidUsage("Improper websocket connection.") | ||||||
|  |  | ||||||
|  |     def create_websocket_connection( | ||||||
|  |         self, send: ASGISend, receive: ASGIReceive | ||||||
|  |     ) -> WebSocketConnection: | ||||||
|  |         self._websocket_connection = WebSocketConnection(send, receive) | ||||||
|  |         return self._websocket_connection | ||||||
|  |  | ||||||
|  |     def add_task(self) -> None:  # noqa | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     async def send(self, data) -> None: | ||||||
|  |         # TODO: | ||||||
|  |         # - Validation on data and that it is formatted properly and is valid | ||||||
|  |         await self._send(data) | ||||||
|  |  | ||||||
|  |     async def receive(self) -> ASGIMessage: | ||||||
|  |         return await self._receive() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Lifespan: | ||||||
|  |     def __init__(self, asgi_app: "ASGIApp") -> None: | ||||||
|  |         self.asgi_app = asgi_app | ||||||
|  |  | ||||||
|  |         if "before_server_start" in self.asgi_app.sanic_app.listeners: | ||||||
|  |             warnings.warn( | ||||||
|  |                 'You have set a listener for "before_server_start" ' | ||||||
|  |                 "in ASGI mode. " | ||||||
|  |                 "It will be executed as early as possible, but not before " | ||||||
|  |                 "the ASGI server is started." | ||||||
|  |             ) | ||||||
|  |         if "after_server_stop" in self.asgi_app.sanic_app.listeners: | ||||||
|  |             warnings.warn( | ||||||
|  |                 'You have set a listener for "after_server_stop" ' | ||||||
|  |                 "in ASGI mode. " | ||||||
|  |                 "It will be executed as late as possible, but not after " | ||||||
|  |                 "the ASGI server is stopped." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     # async def pre_startup(self) -> None: | ||||||
|  |     #     for handler in self.asgi_app.sanic_app.listeners[ | ||||||
|  |     #         "before_server_start" | ||||||
|  |     #     ]: | ||||||
|  |     #         response = handler( | ||||||
|  |     #             self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||||
|  |     #         ) | ||||||
|  |     #         if isawaitable(response): | ||||||
|  |     #             await response | ||||||
|  |  | ||||||
|  |     async def startup(self) -> None: | ||||||
|  |         for handler in self.asgi_app.sanic_app.listeners[ | ||||||
|  |             "before_server_start" | ||||||
|  |         ]: | ||||||
|  |             response = handler( | ||||||
|  |                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||||
|  |             ) | ||||||
|  |             if isawaitable(response): | ||||||
|  |                 await response | ||||||
|  |  | ||||||
|  |         for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: | ||||||
|  |             response = handler( | ||||||
|  |                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||||
|  |             ) | ||||||
|  |             if isawaitable(response): | ||||||
|  |                 await response | ||||||
|  |  | ||||||
|  |     async def shutdown(self) -> None: | ||||||
|  |         for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]: | ||||||
|  |             response = handler( | ||||||
|  |                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||||
|  |             ) | ||||||
|  |             if isawaitable(response): | ||||||
|  |                 await response | ||||||
|  |  | ||||||
|  |         for handler in self.asgi_app.sanic_app.listeners["after_server_stop"]: | ||||||
|  |             response = handler( | ||||||
|  |                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||||
|  |             ) | ||||||
|  |             if isawaitable(response): | ||||||
|  |                 await response | ||||||
|  |  | ||||||
|  |     async def __call__( | ||||||
|  |         self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||||
|  |     ) -> None: | ||||||
|  |         message = await receive() | ||||||
|  |         if message["type"] == "lifespan.startup": | ||||||
|  |             await self.startup() | ||||||
|  |             await send({"type": "lifespan.startup.complete"}) | ||||||
|  |  | ||||||
|  |         message = await receive() | ||||||
|  |         if message["type"] == "lifespan.shutdown": | ||||||
|  |             await self.shutdown() | ||||||
|  |             await send({"type": "lifespan.shutdown.complete"}) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ASGIApp: | ||||||
|  |     def __init__(self) -> None: | ||||||
|  |         self.ws = None | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     async def create( | ||||||
|  |         cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||||
|  |     ) -> "ASGIApp": | ||||||
|  |         instance = cls() | ||||||
|  |         instance.sanic_app = sanic_app | ||||||
|  |         instance.transport = MockTransport(scope, receive, send) | ||||||
|  |         instance.transport.add_task = sanic_app.loop.create_task | ||||||
|  |         instance.transport.loop = sanic_app.loop | ||||||
|  |  | ||||||
|  |         headers = CIMultiDict( | ||||||
|  |             [ | ||||||
|  |                 (key.decode("latin-1"), value.decode("latin-1")) | ||||||
|  |                 for key, value in scope.get("headers", []) | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         instance.do_stream = ( | ||||||
|  |             True if headers.get("expect") == "100-continue" else False | ||||||
|  |         ) | ||||||
|  |         instance.lifespan = Lifespan(instance) | ||||||
|  |  | ||||||
|  |         if scope["type"] == "lifespan": | ||||||
|  |             await instance.lifespan(scope, receive, send) | ||||||
|  |         else: | ||||||
|  |             url_bytes = scope.get("root_path", "") + quote(scope["path"]) | ||||||
|  |             url_bytes = url_bytes.encode("latin-1") | ||||||
|  |             url_bytes += b"?" + scope["query_string"] | ||||||
|  |  | ||||||
|  |             if scope["type"] == "http": | ||||||
|  |                 version = scope["http_version"] | ||||||
|  |                 method = scope["method"] | ||||||
|  |             elif scope["type"] == "websocket": | ||||||
|  |                 version = "1.1" | ||||||
|  |                 method = "GET" | ||||||
|  |  | ||||||
|  |                 instance.ws = instance.transport.create_websocket_connection( | ||||||
|  |                     send, receive | ||||||
|  |                 ) | ||||||
|  |                 await instance.ws.accept() | ||||||
|  |             else: | ||||||
|  |                 pass | ||||||
|  |                 # TODO: | ||||||
|  |                 # - close connection | ||||||
|  |  | ||||||
|  |             instance.request = Request( | ||||||
|  |                 url_bytes, | ||||||
|  |                 headers, | ||||||
|  |                 version, | ||||||
|  |                 method, | ||||||
|  |                 instance.transport, | ||||||
|  |                 sanic_app, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             if sanic_app.is_request_stream: | ||||||
|  |                 is_stream_handler = sanic_app.router.is_stream_handler( | ||||||
|  |                     instance.request | ||||||
|  |                 ) | ||||||
|  |                 if is_stream_handler: | ||||||
|  |                     instance.request.stream = StreamBuffer( | ||||||
|  |                         sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE | ||||||
|  |                     ) | ||||||
|  |                     instance.do_stream = True | ||||||
|  |  | ||||||
|  |         return instance | ||||||
|  |  | ||||||
|  |     async def read_body(self) -> bytes: | ||||||
|  |         """ | ||||||
|  |         Read and return the entire body from an incoming ASGI message. | ||||||
|  |         """ | ||||||
|  |         body = b"" | ||||||
|  |         more_body = True | ||||||
|  |         while more_body: | ||||||
|  |             message = await self.transport.receive() | ||||||
|  |             body += message.get("body", b"") | ||||||
|  |             more_body = message.get("more_body", False) | ||||||
|  |  | ||||||
|  |         return body | ||||||
|  |  | ||||||
|  |     async def stream_body(self) -> None: | ||||||
|  |         """ | ||||||
|  |         Read and stream the body in chunks from an incoming ASGI message. | ||||||
|  |         """ | ||||||
|  |         more_body = True | ||||||
|  |  | ||||||
|  |         while more_body: | ||||||
|  |             message = await self.transport.receive() | ||||||
|  |             chunk = message.get("body", b"") | ||||||
|  |             await self.request.stream.put(chunk) | ||||||
|  |  | ||||||
|  |             more_body = message.get("more_body", False) | ||||||
|  |  | ||||||
|  |         await self.request.stream.put(None) | ||||||
|  |  | ||||||
|  |     async def __call__(self) -> None: | ||||||
|  |         """ | ||||||
|  |         Handle the incoming request. | ||||||
|  |         """ | ||||||
|  |         if not self.do_stream: | ||||||
|  |             self.request.body = await self.read_body() | ||||||
|  |         else: | ||||||
|  |             self.sanic_app.loop.create_task(self.stream_body()) | ||||||
|  |  | ||||||
|  |         handler = self.sanic_app.handle_request | ||||||
|  |         callback = None if self.ws else self.stream_callback | ||||||
|  |         await handler(self.request, None, callback) | ||||||
|  |  | ||||||
|  |     async def stream_callback(self, response: HTTPResponse) -> None: | ||||||
|  |         """ | ||||||
|  |         Write the response. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             headers = [ | ||||||
|  |                 (str(name).encode("latin-1"), str(value).encode("latin-1")) | ||||||
|  |                 for name, value in response.headers.items() | ||||||
|  |             ] | ||||||
|  |         except AttributeError: | ||||||
|  |             logger.error( | ||||||
|  |                 "Invalid response object for url %s, " | ||||||
|  |                 "Expected Type: HTTPResponse, Actual Type: %s", | ||||||
|  |                 self.request.url, | ||||||
|  |                 type(response), | ||||||
|  |             ) | ||||||
|  |             exception = ServerError("Invalid response type") | ||||||
|  |             response = self.sanic_app.error_handler.response( | ||||||
|  |                 self.request, exception | ||||||
|  |             ) | ||||||
|  |             headers = [ | ||||||
|  |                 (str(name).encode("latin-1"), str(value).encode("latin-1")) | ||||||
|  |                 for name, value in response.headers.items() | ||||||
|  |                 if name not in (b"Set-Cookie",) | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |         if "content-length" not in response.headers and not isinstance( | ||||||
|  |             response, StreamingHTTPResponse | ||||||
|  |         ): | ||||||
|  |             headers += [ | ||||||
|  |                 (b"content-length", str(len(response.body)).encode("latin-1")) | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |         if response.cookies: | ||||||
|  |             cookies = SimpleCookie() | ||||||
|  |             cookies.load(response.cookies) | ||||||
|  |             headers += [ | ||||||
|  |                 (b"set-cookie", cookie.encode("utf-8")) | ||||||
|  |                 for name, cookie in response.cookies.items() | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |         await self.transport.send( | ||||||
|  |             { | ||||||
|  |                 "type": "http.response.start", | ||||||
|  |                 "status": response.status, | ||||||
|  |                 "headers": headers, | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if isinstance(response, StreamingHTTPResponse): | ||||||
|  |             response.protocol = self.transport.get_protocol() | ||||||
|  |             await response.stream() | ||||||
|  |             await response.protocol.complete() | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             await self.transport.send( | ||||||
|  |                 { | ||||||
|  |                     "type": "http.response.body", | ||||||
|  |                     "body": response.body, | ||||||
|  |                     "more_body": False, | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
| @@ -87,9 +87,9 @@ class StreamingHTTPResponse(BaseHTTPResponse): | |||||||
|             data = self._encode_body(data) |             data = self._encode_body(data) | ||||||
|  |  | ||||||
|         if self.chunked: |         if self.chunked: | ||||||
|             self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) |             await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) | ||||||
|         else: |         else: | ||||||
|             self.protocol.push_data(data) |             await self.protocol.push_data(data) | ||||||
|         await self.protocol.drain() |         await self.protocol.drain() | ||||||
|  |  | ||||||
|     async def stream( |     async def stream( | ||||||
| @@ -105,11 +105,11 @@ class StreamingHTTPResponse(BaseHTTPResponse): | |||||||
|             keep_alive=keep_alive, |             keep_alive=keep_alive, | ||||||
|             keep_alive_timeout=keep_alive_timeout, |             keep_alive_timeout=keep_alive_timeout, | ||||||
|         ) |         ) | ||||||
|         self.protocol.push_data(headers) |         await self.protocol.push_data(headers) | ||||||
|         await self.protocol.drain() |         await self.protocol.drain() | ||||||
|         await self.streaming_fn(self) |         await self.streaming_fn(self) | ||||||
|         if self.chunked: |         if self.chunked: | ||||||
|             self.protocol.push_data(b"0\r\n\r\n") |             await self.protocol.push_data(b"0\r\n\r\n") | ||||||
|         # no need to await drain here after this write, because it is the |         # no need to await drain here after this write, because it is the | ||||||
|         # very last thing we write and nothing needs to wait for it. |         # very last thing we write and nothing needs to wait for it. | ||||||
|  |  | ||||||
|   | |||||||
| @@ -406,6 +406,7 @@ class Router: | |||||||
|         if not self.hosts: |         if not self.hosts: | ||||||
|             return self._get(request.path, request.method, "") |             return self._get(request.path, request.method, "") | ||||||
|         # virtual hosts specified; try to match route to the host header |         # virtual hosts specified; try to match route to the host header | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             return self._get( |             return self._get( | ||||||
|                 request.path, request.method, request.headers.get("Host", "") |                 request.path, request.method, request.headers.get("Host", "") | ||||||
|   | |||||||
| @@ -477,7 +477,7 @@ class HttpProtocol(asyncio.Protocol): | |||||||
|     async def drain(self): |     async def drain(self): | ||||||
|         await self._not_paused.wait() |         await self._not_paused.wait() | ||||||
|  |  | ||||||
|     def push_data(self, data): |     async def push_data(self, data): | ||||||
|         self.transport.write(data) |         self.transport.write(data) | ||||||
|  |  | ||||||
|     async def stream_response(self, response): |     async def stream_response(self, response): | ||||||
| @@ -728,6 +728,8 @@ def serve( | |||||||
|     if debug: |     if debug: | ||||||
|         loop.set_debug(debug) |         loop.set_debug(debug) | ||||||
|  |  | ||||||
|  |     app.asgi = False | ||||||
|  |  | ||||||
|     connections = connections if connections is not None else set() |     connections = connections if connections is not None else set() | ||||||
|     server = partial( |     server = partial( | ||||||
|         protocol, |         protocol, | ||||||
|   | |||||||
							
								
								
									
										262
									
								
								sanic/testing.py
									
									
									
									
									
								
							
							
						
						
									
										262
									
								
								sanic/testing.py
									
									
									
									
									
								
							| @@ -1,14 +1,22 @@ | |||||||
|  | import asyncio | ||||||
|  | import types | ||||||
|  | import typing | ||||||
|  |  | ||||||
| from json import JSONDecodeError | from json import JSONDecodeError | ||||||
| from socket import socket | from socket import socket | ||||||
|  | from urllib.parse import unquote, urlsplit | ||||||
|  |  | ||||||
|  | import httpcore | ||||||
| import requests_async as requests | import requests_async as requests | ||||||
| import websockets | import websockets | ||||||
|  |  | ||||||
|  | from sanic.asgi import ASGIApp | ||||||
| from sanic.exceptions import MethodNotSupported | from sanic.exceptions import MethodNotSupported | ||||||
| from sanic.log import logger | from sanic.log import logger | ||||||
| from sanic.response import text | from sanic.response import text | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ASGI_HOST = "mockserver" | ||||||
| HOST = "127.0.0.1" | HOST = "127.0.0.1" | ||||||
| PORT = 42101 | PORT = 42101 | ||||||
|  |  | ||||||
| @@ -64,7 +72,7 @@ class SanicTestClient: | |||||||
|         debug=False, |         debug=False, | ||||||
|         server_kwargs={"auto_reload": False}, |         server_kwargs={"auto_reload": False}, | ||||||
|         *request_args, |         *request_args, | ||||||
|         **request_kwargs |         **request_kwargs, | ||||||
|     ): |     ): | ||||||
|         results = [None, None] |         results = [None, None] | ||||||
|         exceptions = [] |         exceptions = [] | ||||||
| @@ -128,7 +136,7 @@ class SanicTestClient: | |||||||
|             try: |             try: | ||||||
|                 request, response = results |                 request, response = results | ||||||
|                 return request, response |                 return request, response | ||||||
|             except BaseException: |             except BaseException:  # noqa | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     "Request and response object expected, got ({})".format( |                     "Request and response object expected, got ({})".format( | ||||||
|                         results |                         results | ||||||
| @@ -137,7 +145,7 @@ class SanicTestClient: | |||||||
|         else: |         else: | ||||||
|             try: |             try: | ||||||
|                 return results[-1] |                 return results[-1] | ||||||
|             except BaseException: |             except BaseException:  # noqa | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     "Request object expected, got ({})".format(results) |                     "Request object expected, got ({})".format(results) | ||||||
|                 ) |                 ) | ||||||
| @@ -165,3 +173,251 @@ class SanicTestClient: | |||||||
|  |  | ||||||
|     def websocket(self, *args, **kwargs): |     def websocket(self, *args, **kwargs): | ||||||
|         return self._sanic_endpoint_test("websocket", *args, **kwargs) |         return self._sanic_endpoint_test("websocket", *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SanicASGIAdapter(requests.asgi.ASGIAdapter):  # noqa | ||||||
|  |     async def send(  # type: ignore | ||||||
|  |         self, | ||||||
|  |         request: requests.PreparedRequest, | ||||||
|  |         gather_return: bool = False, | ||||||
|  |         *args: typing.Any, | ||||||
|  |         **kwargs: typing.Any, | ||||||
|  |     ) -> requests.Response: | ||||||
|  |         """This method is taken MOSTLY verbatim from requests-asyn. The | ||||||
|  |         difference is the capturing of a response on the ASGI call and then | ||||||
|  |         returning it on the response object. This is implemented to achieve: | ||||||
|  |  | ||||||
|  |         request, response = await app.asgi_client.get("/") | ||||||
|  |  | ||||||
|  |         You can see the original code here: | ||||||
|  |         https://github.com/encode/requests-async/blob/614f40f77f19e6c6da8a212ae799107b0384dbf9/requests_async/asgi.py#L51"""  # noqa | ||||||
|  |         scheme, netloc, path, query, fragment = urlsplit( | ||||||
|  |             request.url | ||||||
|  |         )  # type: ignore | ||||||
|  |  | ||||||
|  |         default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] | ||||||
|  |  | ||||||
|  |         if ":" in netloc: | ||||||
|  |             host, port_string = netloc.split(":", 1) | ||||||
|  |             port = int(port_string) | ||||||
|  |         else: | ||||||
|  |             host = netloc | ||||||
|  |             port = default_port | ||||||
|  |  | ||||||
|  |         # Include the 'host' header. | ||||||
|  |         if "host" in request.headers: | ||||||
|  |             headers = []  # type: typing.List[typing.Tuple[bytes, bytes]] | ||||||
|  |         elif port == default_port: | ||||||
|  |             headers = [(b"host", host.encode())] | ||||||
|  |         else: | ||||||
|  |             headers = [(b"host", (f"{host}:{port}").encode())] | ||||||
|  |  | ||||||
|  |         # Include other request headers. | ||||||
|  |         headers += [ | ||||||
|  |             (key.lower().encode(), value.encode()) | ||||||
|  |             for key, value in request.headers.items() | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         no_response = False | ||||||
|  |         if scheme in {"ws", "wss"}: | ||||||
|  |             subprotocol = request.headers.get("sec-websocket-protocol", None) | ||||||
|  |             if subprotocol is None: | ||||||
|  |                 subprotocols = []  # type: typing.Sequence[str] | ||||||
|  |             else: | ||||||
|  |                 subprotocols = [ | ||||||
|  |                     value.strip() for value in subprotocol.split(",") | ||||||
|  |                 ] | ||||||
|  |  | ||||||
|  |             scope = { | ||||||
|  |                 "type": "websocket", | ||||||
|  |                 "path": unquote(path), | ||||||
|  |                 "root_path": "", | ||||||
|  |                 "scheme": scheme, | ||||||
|  |                 "query_string": query.encode(), | ||||||
|  |                 "headers": headers, | ||||||
|  |                 "client": ["testclient", 50000], | ||||||
|  |                 "server": [host, port], | ||||||
|  |                 "subprotocols": subprotocols, | ||||||
|  |             } | ||||||
|  |             no_response = True | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             scope = { | ||||||
|  |                 "type": "http", | ||||||
|  |                 "http_version": "1.1", | ||||||
|  |                 "method": request.method, | ||||||
|  |                 "path": unquote(path), | ||||||
|  |                 "root_path": "", | ||||||
|  |                 "scheme": scheme, | ||||||
|  |                 "query_string": query.encode(), | ||||||
|  |                 "headers": headers, | ||||||
|  |                 "client": ["testclient", 50000], | ||||||
|  |                 "server": [host, port], | ||||||
|  |                 "extensions": {"http.response.template": {}}, | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         async def receive(): | ||||||
|  |             nonlocal request_complete, response_complete | ||||||
|  |  | ||||||
|  |             if request_complete: | ||||||
|  |                 while not response_complete: | ||||||
|  |                     await asyncio.sleep(0.0001) | ||||||
|  |                 return {"type": "http.disconnect"} | ||||||
|  |  | ||||||
|  |             body = request.body | ||||||
|  |             if isinstance(body, str): | ||||||
|  |                 body_bytes = body.encode("utf-8")  # type: bytes | ||||||
|  |             elif body is None: | ||||||
|  |                 body_bytes = b"" | ||||||
|  |             elif isinstance(body, types.GeneratorType): | ||||||
|  |                 try: | ||||||
|  |                     chunk = body.send(None) | ||||||
|  |                     if isinstance(chunk, str): | ||||||
|  |                         chunk = chunk.encode("utf-8") | ||||||
|  |                     return { | ||||||
|  |                         "type": "http.request", | ||||||
|  |                         "body": chunk, | ||||||
|  |                         "more_body": True, | ||||||
|  |                     } | ||||||
|  |                 except StopIteration: | ||||||
|  |                     request_complete = True | ||||||
|  |                     return {"type": "http.request", "body": b""} | ||||||
|  |             else: | ||||||
|  |                 body_bytes = body | ||||||
|  |  | ||||||
|  |             request_complete = True | ||||||
|  |             return {"type": "http.request", "body": body_bytes} | ||||||
|  |  | ||||||
|  |         async def send(message) -> None: | ||||||
|  |             nonlocal raw_kwargs, response_started, response_complete, template, context  # noqa | ||||||
|  |  | ||||||
|  |             if message["type"] == "http.response.start": | ||||||
|  |                 assert ( | ||||||
|  |                     not response_started | ||||||
|  |                 ), 'Received multiple "http.response.start" messages.' | ||||||
|  |                 raw_kwargs["status_code"] = message["status"] | ||||||
|  |                 raw_kwargs["headers"] = message["headers"] | ||||||
|  |                 response_started = True | ||||||
|  |             elif message["type"] == "http.response.body": | ||||||
|  |                 assert response_started, ( | ||||||
|  |                     'Received "http.response.body" ' | ||||||
|  |                     'without "http.response.start".' | ||||||
|  |                 ) | ||||||
|  |                 assert ( | ||||||
|  |                     not response_complete | ||||||
|  |                 ), 'Received "http.response.body" after response completed.' | ||||||
|  |                 body = message.get("body", b"") | ||||||
|  |                 more_body = message.get("more_body", False) | ||||||
|  |                 if request.method != "HEAD": | ||||||
|  |                     raw_kwargs["content"] += body | ||||||
|  |                 if not more_body: | ||||||
|  |                     response_complete = True | ||||||
|  |             elif message["type"] == "http.response.template": | ||||||
|  |                 template = message["template"] | ||||||
|  |                 context = message["context"] | ||||||
|  |  | ||||||
|  |         request_complete = False | ||||||
|  |         response_started = False | ||||||
|  |         response_complete = False | ||||||
|  |         raw_kwargs = {"content": b""}  # type: typing.Dict[str, typing.Any] | ||||||
|  |         template = None | ||||||
|  |         context = None | ||||||
|  |         return_value = None | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             return_value = await self.app(scope, receive, send) | ||||||
|  |         except BaseException as exc: | ||||||
|  |             if not self.suppress_exceptions: | ||||||
|  |                 raise exc from None | ||||||
|  |  | ||||||
|  |         if no_response: | ||||||
|  |             response_started = True | ||||||
|  |             raw_kwargs = {"status_code": 204, "headers": []} | ||||||
|  |  | ||||||
|  |         if not self.suppress_exceptions: | ||||||
|  |             assert response_started, "TestClient did not receive any response." | ||||||
|  |         elif not response_started: | ||||||
|  |             raw_kwargs = {"status_code": 500, "headers": []} | ||||||
|  |  | ||||||
|  |         raw = httpcore.Response(**raw_kwargs) | ||||||
|  |         response = self.build_response(request, raw) | ||||||
|  |         if template is not None: | ||||||
|  |             response.template = template | ||||||
|  |             response.context = context | ||||||
|  |  | ||||||
|  |         if gather_return: | ||||||
|  |             response.return_value = return_value | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestASGIApp(ASGIApp): | ||||||
|  |     async def __call__(self): | ||||||
|  |         await super().__call__() | ||||||
|  |         return self.request | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def app_call_with_return(self, scope, receive, send): | ||||||
|  |     asgi_app = await TestASGIApp.create(self, scope, receive, send) | ||||||
|  |     return await asgi_app() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SanicASGITestClient(requests.ASGISession): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         app, | ||||||
|  |         base_url: str = "http://{}".format(ASGI_HOST), | ||||||
|  |         suppress_exceptions: bool = False, | ||||||
|  |     ) -> None: | ||||||
|  |         app.__class__.__call__ = app_call_with_return | ||||||
|  |         app.asgi = True | ||||||
|  |         super().__init__(app) | ||||||
|  |  | ||||||
|  |         adapter = SanicASGIAdapter( | ||||||
|  |             app, suppress_exceptions=suppress_exceptions | ||||||
|  |         ) | ||||||
|  |         self.mount("http://", adapter) | ||||||
|  |         self.mount("https://", adapter) | ||||||
|  |         self.mount("ws://", adapter) | ||||||
|  |         self.mount("wss://", adapter) | ||||||
|  |         self.headers.update({"user-agent": "testclient"}) | ||||||
|  |         self.app = app | ||||||
|  |         self.base_url = base_url | ||||||
|  |  | ||||||
|  |     async def request(self, method, url, gather_request=True, *args, **kwargs): | ||||||
|  |  | ||||||
|  |         self.gather_request = gather_request | ||||||
|  |         response = await super().request(method, url, *args, **kwargs) | ||||||
|  |         response.status = response.status_code | ||||||
|  |         response.body = response.content | ||||||
|  |         response.content_type = response.headers.get("content-type") | ||||||
|  |  | ||||||
|  |         if hasattr(response, "return_value"): | ||||||
|  |             request = response.return_value | ||||||
|  |             del response.return_value | ||||||
|  |             return request, response | ||||||
|  |  | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |     def merge_environment_settings(self, *args, **kwargs): | ||||||
|  |         settings = super().merge_environment_settings(*args, **kwargs) | ||||||
|  |         settings.update({"gather_return": self.gather_request}) | ||||||
|  |         return settings | ||||||
|  |  | ||||||
|  |     async def websocket(self, uri, subprotocols=None, *args, **kwargs): | ||||||
|  |         if uri.startswith(("ws:", "wss:")): | ||||||
|  |             url = uri | ||||||
|  |         else: | ||||||
|  |             uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri) | ||||||
|  |             url = "ws://testserver{uri}".format(uri=uri) | ||||||
|  |  | ||||||
|  |             headers = kwargs.get("headers", {}) | ||||||
|  |             headers.setdefault("connection", "upgrade") | ||||||
|  |             headers.setdefault("sec-websocket-key", "testserver==") | ||||||
|  |             headers.setdefault("sec-websocket-version", "13") | ||||||
|  |             if subprotocols is not None: | ||||||
|  |                 headers.setdefault( | ||||||
|  |                     "sec-websocket-protocol", ", ".join(subprotocols) | ||||||
|  |                 ) | ||||||
|  |             kwargs["headers"] = headers | ||||||
|  |  | ||||||
|  |             return await self.request("websocket", url, **kwargs) | ||||||
|   | |||||||
| @@ -1,3 +1,5 @@ | |||||||
|  | from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union | ||||||
|  |  | ||||||
| from httptools import HttpParserUpgrade | from httptools import HttpParserUpgrade | ||||||
| from websockets import ConnectionClosed  # noqa | from websockets import ConnectionClosed  # noqa | ||||||
| from websockets import InvalidHandshake, WebSocketCommonProtocol, handshake | from websockets import InvalidHandshake, WebSocketCommonProtocol, handshake | ||||||
| @@ -6,6 +8,9 @@ from sanic.exceptions import InvalidUsage | |||||||
| from sanic.server import HttpProtocol | from sanic.server import HttpProtocol | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ASIMessage = MutableMapping[str, Any] | ||||||
|  |  | ||||||
|  |  | ||||||
| class WebSocketProtocol(HttpProtocol): | class WebSocketProtocol(HttpProtocol): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
| @@ -19,6 +24,7 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|     ): |     ): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.websocket = None |         self.websocket = None | ||||||
|  |         # self.app = None | ||||||
|         self.websocket_timeout = websocket_timeout |         self.websocket_timeout = websocket_timeout | ||||||
|         self.websocket_max_size = websocket_max_size |         self.websocket_max_size = websocket_max_size | ||||||
|         self.websocket_max_queue = websocket_max_queue |         self.websocket_max_queue = websocket_max_queue | ||||||
| @@ -103,3 +109,45 @@ class WebSocketProtocol(HttpProtocol): | |||||||
|         self.websocket.connection_made(request.transport) |         self.websocket.connection_made(request.transport) | ||||||
|         self.websocket.connection_open() |         self.websocket.connection_open() | ||||||
|         return self.websocket |         return self.websocket | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebSocketConnection: | ||||||
|  |  | ||||||
|  |     # TODO | ||||||
|  |     # - Implement ping/pong | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         send: Callable[[ASIMessage], Awaitable[None]], | ||||||
|  |         receive: Callable[[], Awaitable[ASIMessage]], | ||||||
|  |     ) -> None: | ||||||
|  |         self._send = send | ||||||
|  |         self._receive = receive | ||||||
|  |  | ||||||
|  |     async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: | ||||||
|  |         message = {"type": "websocket.send"} | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             data.decode() | ||||||
|  |         except AttributeError: | ||||||
|  |             message.update({"text": str(data)}) | ||||||
|  |         else: | ||||||
|  |             message.update({"bytes": data}) | ||||||
|  |  | ||||||
|  |         await self._send(message) | ||||||
|  |  | ||||||
|  |     async def recv(self, *args, **kwargs) -> Optional[str]: | ||||||
|  |         message = await self._receive() | ||||||
|  |  | ||||||
|  |         if message["type"] == "websocket.receive": | ||||||
|  |             return message["text"] | ||||||
|  |         elif message["type"] == "websocket.disconnect": | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     receive = recv | ||||||
|  |  | ||||||
|  |     async def accept(self) -> None: | ||||||
|  |         await self._send({"type": "websocket.accept", "subprotocol": ""}) | ||||||
|  |  | ||||||
|  |     async def close(self) -> None: | ||||||
|  |         pass | ||||||
|   | |||||||
| @@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app): | |||||||
|  |  | ||||||
| def test_app_loop_not_running(app): | def test_app_loop_not_running(app): | ||||||
|     with pytest.raises(SanicException) as excinfo: |     with pytest.raises(SanicException) as excinfo: | ||||||
|         _ = app.loop |         app.loop | ||||||
|  |  | ||||||
|     assert str(excinfo.value) == ( |     assert str(excinfo.value) == ( | ||||||
|         "Loop can only be retrieved after the app has started " |         "Loop can only be retrieved after the app has started " | ||||||
|   | |||||||
							
								
								
									
										203
									
								
								tests/test_asgi.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								tests/test_asgi.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,203 @@ | |||||||
|  | import asyncio | ||||||
|  |  | ||||||
|  | from collections import deque | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  | import uvicorn | ||||||
|  |  | ||||||
|  | from sanic.asgi import MockTransport | ||||||
|  | from sanic.exceptions import InvalidUsage | ||||||
|  | from sanic.websocket import WebSocketConnection | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def message_stack(): | ||||||
|  |     return deque() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def receive(message_stack): | ||||||
|  |     async def _receive(): | ||||||
|  |         return message_stack.popleft() | ||||||
|  |  | ||||||
|  |     return _receive | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def send(message_stack): | ||||||
|  |     async def _send(message): | ||||||
|  |         message_stack.append(message) | ||||||
|  |  | ||||||
|  |     return _send | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def transport(message_stack, receive, send): | ||||||
|  |     return MockTransport({}, receive, send) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | # @pytest.mark.asyncio | ||||||
|  | def protocol(transport, loop): | ||||||
|  |     return transport.get_protocol() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_listeners_triggered(app): | ||||||
|  |     before_server_start = False | ||||||
|  |     after_server_start = False | ||||||
|  |     before_server_stop = False | ||||||
|  |     after_server_stop = False | ||||||
|  |  | ||||||
|  |     @app.listener("before_server_start") | ||||||
|  |     def do_before_server_start(*args, **kwargs): | ||||||
|  |         nonlocal before_server_start | ||||||
|  |         before_server_start = True | ||||||
|  |  | ||||||
|  |     @app.listener("after_server_start") | ||||||
|  |     def do_after_server_start(*args, **kwargs): | ||||||
|  |         nonlocal after_server_start | ||||||
|  |         after_server_start = True | ||||||
|  |  | ||||||
|  |     @app.listener("before_server_stop") | ||||||
|  |     def do_before_server_stop(*args, **kwargs): | ||||||
|  |         nonlocal before_server_stop | ||||||
|  |         before_server_stop = True | ||||||
|  |  | ||||||
|  |     @app.listener("after_server_stop") | ||||||
|  |     def do_after_server_stop(*args, **kwargs): | ||||||
|  |         nonlocal after_server_stop | ||||||
|  |         after_server_stop = True | ||||||
|  |  | ||||||
|  |     class CustomServer(uvicorn.Server): | ||||||
|  |         def install_signal_handlers(self): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) | ||||||
|  |     server = CustomServer(config=config) | ||||||
|  |  | ||||||
|  |     with pytest.warns(UserWarning): | ||||||
|  |         server.run() | ||||||
|  |  | ||||||
|  |     for task in asyncio.Task.all_tasks(): | ||||||
|  |         task.cancel() | ||||||
|  |  | ||||||
|  |     assert before_server_start | ||||||
|  |     assert after_server_start | ||||||
|  |     assert before_server_stop | ||||||
|  |     assert after_server_stop | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_listeners_triggered_async(app): | ||||||
|  |     before_server_start = False | ||||||
|  |     after_server_start = False | ||||||
|  |     before_server_stop = False | ||||||
|  |     after_server_stop = False | ||||||
|  |  | ||||||
|  |     @app.listener("before_server_start") | ||||||
|  |     async def do_before_server_start(*args, **kwargs): | ||||||
|  |         nonlocal before_server_start | ||||||
|  |         before_server_start = True | ||||||
|  |  | ||||||
|  |     @app.listener("after_server_start") | ||||||
|  |     async def do_after_server_start(*args, **kwargs): | ||||||
|  |         nonlocal after_server_start | ||||||
|  |         after_server_start = True | ||||||
|  |  | ||||||
|  |     @app.listener("before_server_stop") | ||||||
|  |     async def do_before_server_stop(*args, **kwargs): | ||||||
|  |         nonlocal before_server_stop | ||||||
|  |         before_server_stop = True | ||||||
|  |  | ||||||
|  |     @app.listener("after_server_stop") | ||||||
|  |     async def do_after_server_stop(*args, **kwargs): | ||||||
|  |         nonlocal after_server_stop | ||||||
|  |         after_server_stop = True | ||||||
|  |  | ||||||
|  |     class CustomServer(uvicorn.Server): | ||||||
|  |         def install_signal_handlers(self): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) | ||||||
|  |     server = CustomServer(config=config) | ||||||
|  |  | ||||||
|  |     with pytest.warns(UserWarning): | ||||||
|  |         server.run() | ||||||
|  |  | ||||||
|  |     for task in asyncio.Task.all_tasks(): | ||||||
|  |         task.cancel() | ||||||
|  |  | ||||||
|  |     assert before_server_start | ||||||
|  |     assert after_server_start | ||||||
|  |     assert before_server_stop | ||||||
|  |     assert after_server_stop | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_mockprotocol_events(protocol): | ||||||
|  |     assert protocol._not_paused.is_set() | ||||||
|  |     protocol.pause_writing() | ||||||
|  |     assert not protocol._not_paused.is_set() | ||||||
|  |     protocol.resume_writing() | ||||||
|  |     assert protocol._not_paused.is_set() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_protocol_push_data(protocol, message_stack): | ||||||
|  |     text = b"hello" | ||||||
|  |  | ||||||
|  |     await protocol.push_data(text) | ||||||
|  |     await protocol.complete() | ||||||
|  |  | ||||||
|  |     assert len(message_stack) == 2 | ||||||
|  |  | ||||||
|  |     message = message_stack.popleft() | ||||||
|  |     assert message["type"] == "http.response.body" | ||||||
|  |     assert message["more_body"] | ||||||
|  |     assert message["body"] == text | ||||||
|  |  | ||||||
|  |     message = message_stack.popleft() | ||||||
|  |     assert message["type"] == "http.response.body" | ||||||
|  |     assert not message["more_body"] | ||||||
|  |     assert message["body"] == b"" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_websocket_send(send, receive, message_stack): | ||||||
|  |     text_string = "hello" | ||||||
|  |     text_bytes = b"hello" | ||||||
|  |  | ||||||
|  |     ws = WebSocketConnection(send, receive) | ||||||
|  |     await ws.send(text_string) | ||||||
|  |     await ws.send(text_bytes) | ||||||
|  |  | ||||||
|  |     assert len(message_stack) == 2 | ||||||
|  |  | ||||||
|  |     message = message_stack.popleft() | ||||||
|  |     assert message["type"] == "websocket.send" | ||||||
|  |     assert message["text"] == text_string | ||||||
|  |     assert "bytes" not in message | ||||||
|  |  | ||||||
|  |     message = message_stack.popleft() | ||||||
|  |     assert message["type"] == "websocket.send" | ||||||
|  |     assert message["bytes"] == text_bytes | ||||||
|  |     assert "text" not in message | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_websocket_receive(send, receive, message_stack): | ||||||
|  |     msg = {"text": "hello", "type": "websocket.receive"} | ||||||
|  |     message_stack.append(msg) | ||||||
|  |  | ||||||
|  |     ws = WebSocketConnection(send, receive) | ||||||
|  |     text = await ws.receive() | ||||||
|  |  | ||||||
|  |     assert text == msg["text"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_improper_websocket_connection(transport, send, receive): | ||||||
|  |     with pytest.raises(InvalidUsage): | ||||||
|  |         transport.get_websocket_connection() | ||||||
|  |  | ||||||
|  |     transport.create_websocket_connection(send, receive) | ||||||
|  |     connection = transport.get_websocket_connection() | ||||||
|  |     assert isinstance(connection, WebSocketConnection) | ||||||
							
								
								
									
										5
									
								
								tests/test_asgi_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								tests/test_asgi_client.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | from sanic.testing import SanicASGITestClient | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_asgi_client_instantiation(app): | ||||||
|  |     assert isinstance(app.asgi_client, SanicASGITestClient) | ||||||
| @@ -239,6 +239,7 @@ def test_config_access_log_passing_in_run(app): | |||||||
|     assert app.config.ACCESS_LOG == True |     assert app.config.ACCESS_LOG == True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
| async def test_config_access_log_passing_in_create_server(app): | async def test_config_access_log_passing_in_create_server(app): | ||||||
|     assert app.config.ACCESS_LOG == True |     assert app.config.ACCESS_LOG == True | ||||||
|  |  | ||||||
|   | |||||||
| @@ -27,6 +27,24 @@ def test_cookies(app): | |||||||
|     assert response_cookies["right_back"].value == "at you" |     assert response_cookies["right_back"].value == "at you" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_cookies_asgi(app): | ||||||
|  |     @app.route("/") | ||||||
|  |     def handler(request): | ||||||
|  |         response = text("Cookies are: {}".format(request.cookies["test"])) | ||||||
|  |         response.cookies["right_back"] = "at you" | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.get( | ||||||
|  |         "/", cookies={"test": "working!"} | ||||||
|  |     ) | ||||||
|  |     response_cookies = SimpleCookie() | ||||||
|  |     response_cookies.load(response.headers.get("set-cookie", {})) | ||||||
|  |  | ||||||
|  |     assert response.text == "Cookies are: working!" | ||||||
|  |     assert response_cookies["right_back"].value == "at you" | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)]) | @pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)]) | ||||||
| def test_false_cookies_encoded(app, httponly, expected): | def test_false_cookies_encoded(app, httponly, expected): | ||||||
|     @app.route("/") |     @app.route("/") | ||||||
|   | |||||||
| @@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app): | |||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"]) | @pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"]) | ||||||
| async def test_redirect_with_params(app, sanic_client, test_str): | def test_redirect_with_params(app, test_str): | ||||||
|  |     use_in_uri = quote(test_str) | ||||||
|  |  | ||||||
|     @app.route("/api/v1/test/<test>/") |     @app.route("/api/v1/test/<test>/") | ||||||
|     async def init_handler(request, test): |     async def init_handler(request, test): | ||||||
|         assert test == test_str |         return redirect("/api/v2/test/{}/".format(use_in_uri)) | ||||||
|         return redirect("/api/v2/test/{}/".format(quote(test))) |  | ||||||
|  |  | ||||||
|     @app.route("/api/v2/test/<test>/") |     @app.route("/api/v2/test/<test>/") | ||||||
|     async def target_handler(request, test): |     async def target_handler(request, test): | ||||||
|         assert test == test_str |         assert test == test_str | ||||||
|         return text("OK") |         return text("OK") | ||||||
|  |  | ||||||
|     test_cli = await sanic_client(app) |     _, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri)) | ||||||
|  |  | ||||||
|     response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str))) |  | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|  |  | ||||||
|     txt = await response.text() |     assert response.content == b"OK" | ||||||
|     assert txt == "OK" |  | ||||||
|   | |||||||
| @@ -1,10 +1,13 @@ | |||||||
| import asyncio | import asyncio | ||||||
| import contextlib | import contextlib | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
| from sanic.response import stream, text | from sanic.response import stream, text | ||||||
|  |  | ||||||
|  |  | ||||||
| async def test_request_cancel_when_connection_lost(loop, app, sanic_client): | @pytest.mark.asyncio | ||||||
|  | async def test_request_cancel_when_connection_lost(app): | ||||||
|     app.still_serving_cancelled_request = False |     app.still_serving_cancelled_request = False | ||||||
|  |  | ||||||
|     @app.get("/") |     @app.get("/") | ||||||
| @@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, sanic_client): | |||||||
|         app.still_serving_cancelled_request = True |         app.still_serving_cancelled_request = True | ||||||
|         return text("OK") |         return text("OK") | ||||||
|  |  | ||||||
|     test_cli = await sanic_client(app) |  | ||||||
|  |  | ||||||
|     # schedule client call |     # schedule client call | ||||||
|     task = loop.create_task(test_cli.get("/")) |     loop = asyncio.get_event_loop() | ||||||
|  |     task = loop.create_task(app.asgi_client.get("/")) | ||||||
|     loop.call_later(0.01, task) |     loop.call_later(0.01, task) | ||||||
|     await asyncio.sleep(0.5) |     await asyncio.sleep(0.5) | ||||||
|  |  | ||||||
| @@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, sanic_client): | |||||||
|     assert app.still_serving_cancelled_request is False |     assert app.still_serving_cancelled_request is False | ||||||
|  |  | ||||||
|  |  | ||||||
| async def test_stream_request_cancel_when_conn_lost(loop, app, sanic_client): | @pytest.mark.asyncio | ||||||
|  | async def test_stream_request_cancel_when_conn_lost(app): | ||||||
|     app.still_serving_cancelled_request = False |     app.still_serving_cancelled_request = False | ||||||
|  |  | ||||||
|     @app.post("/post/<id>", stream=True) |     @app.post("/post/<id>", stream=True) | ||||||
| @@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, sanic_client): | |||||||
|  |  | ||||||
|         return stream(streaming) |         return stream(streaming) | ||||||
|  |  | ||||||
|     test_cli = await sanic_client(app) |  | ||||||
|  |  | ||||||
|     # schedule client call |     # schedule client call | ||||||
|     task = loop.create_task(test_cli.post("/post/1")) |     loop = asyncio.get_event_loop() | ||||||
|  |     task = loop.create_task(app.asgi_client.post("/post/1")) | ||||||
|     loop.call_later(0.01, task) |     loop.call_later(0.01, task) | ||||||
|     await asyncio.sleep(0.5) |     await asyncio.sleep(0.5) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from sanic.blueprints import Blueprint | from sanic.blueprints import Blueprint | ||||||
| from sanic.exceptions import HeaderExpectationFailed | from sanic.exceptions import HeaderExpectationFailed | ||||||
| from sanic.request import StreamBuffer | from sanic.request import StreamBuffer | ||||||
| @@ -42,13 +43,15 @@ def test_request_stream_method_view(app): | |||||||
|     assert response.text == data |     assert response.text == data | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("headers, expect_raise_exception", [ | @pytest.mark.parametrize( | ||||||
| ({"EXPECT": "100-continue"}, False), |     "headers, expect_raise_exception", | ||||||
| ({"EXPECT": "100-continue-extra"}, True), |     [ | ||||||
| ]) |         ({"EXPECT": "100-continue"}, False), | ||||||
|  |         ({"EXPECT": "100-continue-extra"}, True), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
| def test_request_stream_100_continue(app, headers, expect_raise_exception): | def test_request_stream_100_continue(app, headers, expect_raise_exception): | ||||||
|     class SimpleView(HTTPMethodView): |     class SimpleView(HTTPMethodView): | ||||||
|  |  | ||||||
|         @stream_decorator |         @stream_decorator | ||||||
|         async def post(self, request): |         async def post(self, request): | ||||||
|             assert isinstance(request.stream, StreamBuffer) |             assert isinstance(request.stream, StreamBuffer) | ||||||
| @@ -65,12 +68,18 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception): | |||||||
|     assert app.is_request_stream is True |     assert app.is_request_stream is True | ||||||
|  |  | ||||||
|     if not expect_raise_exception: |     if not expect_raise_exception: | ||||||
|         request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"}) |         request, response = app.test_client.post( | ||||||
|  |             "/method_view", data=data, headers={"EXPECT": "100-continue"} | ||||||
|  |         ) | ||||||
|         assert response.status == 200 |         assert response.status == 200 | ||||||
|         assert response.text == data |         assert response.text == data | ||||||
|     else: |     else: | ||||||
|         with pytest.raises(ValueError) as e: |         with pytest.raises(ValueError) as e: | ||||||
|             app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) |             app.test_client.post( | ||||||
|  |                 "/method_view", | ||||||
|  |                 data=data, | ||||||
|  |                 headers={"EXPECT": "100-continue-extra"}, | ||||||
|  |             ) | ||||||
|             assert "Unknown Expect: 100-continue-extra" in str(e) |             assert "Unknown Expect: 100-continue-extra" in str(e) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -188,6 +197,121 @@ def test_request_stream_app(app): | |||||||
|     assert response.text == data |     assert response.text == data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_request_stream_app_asgi(app): | ||||||
|  |     """for self.is_request_stream = True and decorators""" | ||||||
|  |  | ||||||
|  |     @app.get("/get") | ||||||
|  |     async def get(request): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("GET") | ||||||
|  |  | ||||||
|  |     @app.head("/head") | ||||||
|  |     async def head(request): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("HEAD") | ||||||
|  |  | ||||||
|  |     @app.delete("/delete") | ||||||
|  |     async def delete(request): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("DELETE") | ||||||
|  |  | ||||||
|  |     @app.options("/options") | ||||||
|  |     async def options(request): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("OPTIONS") | ||||||
|  |  | ||||||
|  |     @app.post("/_post/<id>") | ||||||
|  |     async def _post(request, id): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("_POST") | ||||||
|  |  | ||||||
|  |     @app.post("/post/<id>", stream=True) | ||||||
|  |     async def post(request, id): | ||||||
|  |         assert isinstance(request.stream, StreamBuffer) | ||||||
|  |         result = "" | ||||||
|  |         while True: | ||||||
|  |             body = await request.stream.read() | ||||||
|  |             if body is None: | ||||||
|  |                 break | ||||||
|  |             result += body.decode("utf-8") | ||||||
|  |         return text(result) | ||||||
|  |  | ||||||
|  |     @app.put("/_put") | ||||||
|  |     async def _put(request): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("_PUT") | ||||||
|  |  | ||||||
|  |     @app.put("/put", stream=True) | ||||||
|  |     async def put(request): | ||||||
|  |         assert isinstance(request.stream, StreamBuffer) | ||||||
|  |         result = "" | ||||||
|  |         while True: | ||||||
|  |             body = await request.stream.read() | ||||||
|  |             if body is None: | ||||||
|  |                 break | ||||||
|  |             result += body.decode("utf-8") | ||||||
|  |         return text(result) | ||||||
|  |  | ||||||
|  |     @app.patch("/_patch") | ||||||
|  |     async def _patch(request): | ||||||
|  |         assert request.stream is None | ||||||
|  |         return text("_PATCH") | ||||||
|  |  | ||||||
|  |     @app.patch("/patch", stream=True) | ||||||
|  |     async def patch(request): | ||||||
|  |         assert isinstance(request.stream, StreamBuffer) | ||||||
|  |         result = "" | ||||||
|  |         while True: | ||||||
|  |             body = await request.stream.read() | ||||||
|  |             if body is None: | ||||||
|  |                 break | ||||||
|  |             result += body.decode("utf-8") | ||||||
|  |         return text(result) | ||||||
|  |  | ||||||
|  |     assert app.is_request_stream is True | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.get("/get") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "GET" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.head("/head") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.delete("/delete") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "DELETE" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.options("/options") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "OPTIONS" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.post("/_post/1", data=data) | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "_POST" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.post("/post/1", data=data) | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == data | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.put("/_put", data=data) | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "_PUT" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.put("/put", data=data) | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == data | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.patch("/_patch", data=data) | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "_PATCH" | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.patch("/patch", data=data) | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == data | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_request_stream_handle_exception(app): | def test_request_stream_handle_exception(app): | ||||||
|     """for handling exceptions properly""" |     """for handling exceptions properly""" | ||||||
|  |  | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked( | |||||||
|     async def mock_drain(): |     async def mock_drain(): | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def mock_push_data(data): |     async def mock_push_data(data): | ||||||
|         response.protocol.transport.write(data) |         response.protocol.transport.write(data) | ||||||
|  |  | ||||||
|     response.protocol.push_data = mock_push_data |     response.protocol.push_data = mock_push_data | ||||||
| @@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked( | |||||||
|     async def mock_drain(): |     async def mock_drain(): | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def mock_push_data(data): |     async def mock_push_data(data): | ||||||
|         response.protocol.transport.write(data) |         response.protocol.transport.write(data) | ||||||
|  |  | ||||||
|     response.protocol.push_data = mock_push_data |     response.protocol.push_data = mock_push_data | ||||||
|   | |||||||
| @@ -474,6 +474,19 @@ def test_websocket_route(app, url): | |||||||
|     assert ev.is_set() |     assert ev.is_set() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | @pytest.mark.parametrize("url", ["/ws", "ws"]) | ||||||
|  | async def test_websocket_route_asgi(app, url): | ||||||
|  |     ev = asyncio.Event() | ||||||
|  |  | ||||||
|  |     @app.websocket(url) | ||||||
|  |     async def handler(request, ws): | ||||||
|  |         ev.set() | ||||||
|  |  | ||||||
|  |     request, response = await app.asgi_client.websocket(url) | ||||||
|  |     assert ev.is_set() | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_websocket_route_with_subprotocols(app): | def test_websocket_route_with_subprotocols(app): | ||||||
|     results = [] |     results = [] | ||||||
|  |  | ||||||
|   | |||||||
| @@ -76,6 +76,7 @@ def test_all_listeners(app): | |||||||
|         assert app.name + listener_name == output.pop() |         assert app.name + listener_name == output.pop() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
| async def test_trigger_before_events_create_server(app): | async def test_trigger_before_events_create_server(app): | ||||||
|     class MySanicDb: |     class MySanicDb: | ||||||
|         pass |         pass | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								tox.ini
									
									
									
									
									
								
							| @@ -18,6 +18,8 @@ deps = | |||||||
|     beautifulsoup4 |     beautifulsoup4 | ||||||
|     gunicorn |     gunicorn | ||||||
|     pytest-benchmark |     pytest-benchmark | ||||||
|  |     uvicorn | ||||||
|  |     websockets>=6.0,<7.0 | ||||||
| commands = | commands = | ||||||
|     pytest {posargs:tests --cov sanic} |     pytest {posargs:tests --cov sanic} | ||||||
|     - coverage combine --append |     - coverage combine --append | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 7
					7