diff --git a/sanic/app.py b/sanic/app.py index c1a39413..5ddae617 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -8,6 +8,7 @@ from asyncio import CancelledError, Protocol, ensure_future, get_event_loop from collections import defaultdict, deque from functools import partial from inspect import getmodulename, isawaitable, signature, stack +from multidict import CIMultiDict from socket import socket from ssl import Purpose, SSLContext, create_default_context from traceback import format_exc @@ -21,6 +22,7 @@ from sanic.exceptions import SanicException, ServerError, URLBuildError from sanic.handlers import ErrorHandler from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger from sanic.response import HTTPResponse, StreamingHTTPResponse +from sanic.request import Request from sanic.router import Router from sanic.server import HttpProtocol, Signal, serve, serve_multiple from sanic.static import register as static_register @@ -967,7 +969,7 @@ class Sanic: raise CancelledError() # pass the response to the correct callback - if isinstance(response, StreamingHTTPResponse): + if write_callback is None or isinstance(response, StreamingHTTPResponse): await stream_callback(response) else: write_callback(response) @@ -1106,9 +1108,8 @@ class Sanic: """This kills the Sanic""" get_event_loop().stop() - def __call__(self): - """gunicorn compatibility""" - return self + def __call__(self, scope): + return ASGIApp(self, scope) async def create_server( self, @@ -1339,3 +1340,80 @@ class Sanic: def _build_endpoint_name(self, *parts): parts = [self.name, *parts] return ".".join(parts) + + +class MockTransport: + def __init__(self, scope): + self.scope = scope + + def get_extra_info(self, info): + if info == 'peername': + return self.scope.get('server') + elif info == 'sslcontext': + return self.scope.get('scheme') in ["https", "wss"] + +class ASGIApp: + def __init__(self, sanic_app, scope): + self.sanic_app = sanic_app + url_bytes = scope.get('root_path', '') + scope['path'] + url_bytes = url_bytes.encode('latin-1') + url_bytes += scope['query_string'] + headers = CIMultiDict([ + (key.decode('latin-1'), value.decode('latin-1')) + for key, value in scope.get('headers', []) + ]) + version = scope['http_version'] + method = scope['method'] + self.request = Request(url_bytes, headers, version, method, MockTransport(scope)) + self.request.app = sanic_app + + async def read_body(self, receive): + """ + Read and return the entire body from an incoming ASGI message. + """ + body = b'' + more_body = True + + while more_body: + message = await receive() + body += message.get('body', b'') + more_body = message.get('more_body', False) + + return body + + async def __call__(self, receive, send): + """ + Handle the incoming request. + """ + self.send = send + self.request.body = await self.read_body(receive) + handler = self.sanic_app.handle_request + await handler(self.request, None, self.stream_callback) + + async def stream_callback(self, response): + """ + Write the response. + """ + if isinstance(response, StreamingHTTPResponse): + raise NotImplementedError('Not supported') + + headers = [ + (str(name).encode('latin-1'), str(value).encode('latin-1')) + for name, value in response.headers.items() + ] + if 'content-length' not in response.headers: + headers += [( + b'content-length', + str(len(response.body)).encode('latin-1') + )] + + await self.send({ + 'type': 'http.response.start', + 'status': response.status, + 'headers': headers + }) + await self.send({ + 'type': 'http.response.body', + 'body': response.body, + 'more_body': False + }) diff --git a/sanic/testing.py b/sanic/testing.py index 19f87095..dd31aec4 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,137 +1,355 @@ -from json import JSONDecodeError -from sanic.exceptions import MethodNotSupported -from sanic.log import logger -from sanic.response import text +import asyncio +import http +import io +import json +import queue +import threading +import types +import typing +from urllib.parse import unquote, urljoin, urlparse, parse_qs +import requests + +from starlette.types import ASGIApp, Message, Scope +from starlette.websockets import WebSocketDisconnect HOST = "127.0.0.1" PORT = 42101 -class SanicTestClient: - def __init__(self, app, port=PORT): +# Annotations for `Session.request()` +Cookies = typing.Union[ + typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar +] +Params = typing.Union[bytes, typing.MutableMapping[str, str]] +DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO] +TimeOut = typing.Union[float, typing.Tuple[float, float]] +FileType = typing.MutableMapping[str, typing.IO] +AuthType = typing.Union[ + typing.Tuple[str, str], + requests.auth.AuthBase, + typing.Callable[[requests.Request], requests.Request], +] + + +class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key: str, default: str) -> str: + return self.getheaders(key) + + +class _MockOriginalResponse: + """ + We have to jump through some hoops to present the response as if + it was made using urllib3. + """ + + def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: + self.msg = _HeaderDict(headers) + self.closed = False + + def isclosed(self) -> bool: + return self.closed + + +class _Upgrade(Exception): + def __init__(self, session: "WebSocketTestSession") -> None: + self.session = session + + +def _get_reason_phrase(status_code: int) -> str: + try: + return http.HTTPStatus(status_code).phrase + except ValueError: + return "" + + +class _ASGIAdapter(requests.adapters.HTTPAdapter): + def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None: self.app = app - self.port = port + self.raise_server_exceptions = raise_server_exceptions - async def _local_request(self, method, uri, cookies=None, *args, **kwargs): - import aiohttp + def send( # type: ignore + self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any + ) -> requests.Response: + scheme, netloc, path, params, query, fragement = urlparse( # type: ignore + request.url + ) - if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")): - url = uri + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) else: - url = "http://{host}:{port}{uri}".format( - host=HOST, port=self.port, uri=uri - ) + host = netloc + port = default_port - logger.info(url) - conn = aiohttp.TCPConnector(ssl=False) - async with aiohttp.ClientSession( - cookies=cookies, connector=conn - ) as session: - async with getattr(session, method.lower())( - url, *args, **kwargs - ) as response: - try: - response.text = await response.text() - except UnicodeDecodeError: - response.text = None + # Include the 'host' header. + if "host" in request.headers: + headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + elif port == default_port: + headers = [(b"host", host.encode())] + else: + headers = [(b"host", ("%s:%d" % (host, port)).encode())] - try: - response.json = await response.json() - except ( - JSONDecodeError, - UnicodeDecodeError, - aiohttp.ClientResponseError, - ): - response.json = None + # Include other request headers. + headers += [ + (key.lower().encode(), value.encode()) + for key, value in request.headers.items() + ] - response.body = await response.read() - return response - - def _sanic_endpoint_test( - self, - method="get", - uri="/", - gather_request=True, - debug=False, - server_kwargs={"auto_reload": False}, - *request_args, - **request_kwargs - ): - results = [None, None] - exceptions = [] - - if gather_request: - - def _collect_request(request): - if results[0] is None: - results[0] = request - - self.app.request_middleware.appendleft(_collect_request) - - @self.app.exception(MethodNotSupported) - async def error_handler(request, exception): - if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: - return text( - "", exception.status_code, headers=exception.headers - ) + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols = [] # type: typing.Sequence[str] else: - return self.app.error_handler.default(request, exception) + subprotocols = [value.strip() for value in subprotocol.split(",")] + scope = { + "type": "websocket", + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "subprotocols": subprotocols, + } + session = WebSocketTestSession(self.app, scope) + raise _Upgrade(session) - @self.app.listener("after_server_start") - async def _collect_response(sanic, loop): - try: - response = await self._local_request( - method, uri, *request_args, **request_kwargs + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "extensions": {"http.response.template": {}}, + } + + async def receive() -> Message: + nonlocal request_complete, response_complete + + if request_complete: + while not response_complete: + await asyncio.sleep(0.0001) + return {"type": "http.disconnect"} + + body = request.body + if isinstance(body, str): + body_bytes = body.encode("utf-8") # type: bytes + elif body is None: + body_bytes = b"" + elif isinstance(body, types.GeneratorType): + try: + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", "body": chunk, "more_body": True} + except StopIteration: + request_complete = True + return {"type": "http.request", "body": b""} + else: + body_bytes = body + + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message: Message) -> None: + nonlocal raw_kwargs, response_started, response_complete, template, context + + if message["type"] == "http.response.start": + assert ( + not response_started + ), 'Received multiple "http.response.start" messages.' + raw_kwargs["version"] = 11 + raw_kwargs["status"] = message["status"] + raw_kwargs["reason"] = _get_reason_phrase(message["status"]) + raw_kwargs["headers"] = [ + (key.decode(), value.decode()) for key, value in message["headers"] + ] + raw_kwargs["preload_content"] = False + raw_kwargs["original_response"] = _MockOriginalResponse( + raw_kwargs["headers"] ) - results[-1] = response - except Exception as e: - logger.exception("Exception") - exceptions.append(e) - self.app.stop() + response_started = True + elif message["type"] == "http.response.body": + assert ( + response_started + ), 'Received "http.response.body" without "http.response.start".' + assert ( + not response_complete + ), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["body"].write(body) + if not more_body: + raw_kwargs["body"].seek(0) + response_complete = True + elif message["type"] == "http.response.template": + template = message["template"] + context = message["context"] - self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs) - self.app.listeners["after_server_start"].pop() + request_complete = False + response_started = False + response_complete = False + raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] + template = None + context = None - if exceptions: - raise ValueError("Exception during request: {}".format(exceptions)) + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.app.is_running = True + try: + connection = self.app(scope) + loop.run_until_complete(connection(receive, send)) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc from None + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." + elif not response_started: + raw_kwargs = { + "version": 11, + "status": 500, + "reason": "Internal Server Error", + "headers": [], + "preload_content": False, + "original_response": _MockOriginalResponse([]), + "body": io.BytesIO(), + } + + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) + response = self.build_response(request, raw) + if template is not None: + response.template = template + response.context = context + return response + + +class SanicTestClient(requests.Session): + __test__ = False # For pytest to not discover this up. + + def __init__( + self, + app: ASGIApp, + base_url: str = "http://%s:%d" % (HOST, PORT), + raise_server_exceptions: bool = True, + ) -> None: + super(SanicTestClient, self).__init__() + adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions) + self.mount("http://", adapter) + self.mount("https://", adapter) + self.mount("ws://", adapter) + self.mount("wss://", adapter) + self.headers.update({"user-agent": "testclient"}) + self.app = app + self.base_url = base_url + + def request( + self, + method: str, + url: str = '/', + params: Params = None, + data: DataType = None, + headers: typing.MutableMapping[str, str] = None, + cookies: Cookies = None, + files: FileType = None, + auth: AuthType = None, + timeout: TimeOut = None, + allow_redirects: bool = None, + proxies: typing.MutableMapping[str, str] = None, + hooks: typing.Any = None, + stream: bool = None, + verify: typing.Union[bool, str] = None, + cert: typing.Union[str, typing.Tuple[str, str]] = None, + json: typing.Any = None, + debug = None, + gather_request = True + ) -> requests.Response: + if debug is not None: + self.app.debug = debug + + url = urljoin(self.base_url, url) + response = super().request( + method, + url, + params=params, + data=data, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + stream=stream, + verify=verify, + cert=cert, + json=json, + ) + + response.status = response.status_code + response.body = response.content + try: + response.json = response.json() + except: + response.json = None if gather_request: - try: - request, response = results - return request, response - except BaseException: - raise ValueError( - "Request and response object expected, got ({})".format( - results - ) - ) - else: - try: - return results[-1] - except BaseException: - raise ValueError( - "Request object expected, got ({})".format(results) - ) + request = response.request + parsed = urlparse(request.url) + request.scheme = parsed.scheme + request.path = parsed.path + request.args = parse_qs(parsed.query) + return request, response + + return response def get(self, *args, **kwargs): - return self._sanic_endpoint_test("get", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("get", *args, **kwargs) def post(self, *args, **kwargs): - return self._sanic_endpoint_test("post", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("post", *args, **kwargs) def put(self, *args, **kwargs): - return self._sanic_endpoint_test("put", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("put", *args, **kwargs) def delete(self, *args, **kwargs): - return self._sanic_endpoint_test("delete", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("delete", *args, **kwargs) def patch(self, *args, **kwargs): - return self._sanic_endpoint_test("patch", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("patch", *args, **kwargs) def options(self, *args, **kwargs): - return self._sanic_endpoint_test("options", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("options", *args, **kwargs) def head(self, *args, **kwargs): - return self._sanic_endpoint_test("head", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("head", *args, **kwargs)