Merge branch 'master' into asgi-refactor-attempt
This commit is contained in:
@@ -2,6 +2,6 @@ from sanic.app import Sanic
|
||||
from sanic.blueprints import Blueprint
|
||||
|
||||
|
||||
__version__ = "18.12.0"
|
||||
__version__ = "19.03.1"
|
||||
|
||||
__all__ = ["Sanic", "Blueprint"]
|
||||
|
||||
60
sanic/app.py
60
sanic/app.py
@@ -16,6 +16,7 @@ from typing import Any, Optional, Type, Union
|
||||
from urllib.parse import urlencode, urlunparse
|
||||
|
||||
from sanic import reloader_helpers
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.config import BASE_LOGO, Config
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.exceptions import SanicException, ServerError, URLBuildError
|
||||
@@ -181,27 +182,28 @@ class Sanic:
|
||||
strict_slashes = self.strict_slashes
|
||||
|
||||
def response(handler):
|
||||
args = [key for key in signature(handler).parameters.keys()]
|
||||
if args:
|
||||
if stream:
|
||||
handler.is_stream = stream
|
||||
args = list(signature(handler).parameters.keys())
|
||||
|
||||
self.router.add(
|
||||
uri=uri,
|
||||
methods=methods,
|
||||
handler=handler,
|
||||
host=host,
|
||||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
)
|
||||
return handler
|
||||
else:
|
||||
if not args:
|
||||
raise ValueError(
|
||||
"Required parameter `request` missing "
|
||||
"in the {0}() route?".format(handler.__name__)
|
||||
)
|
||||
|
||||
if stream:
|
||||
handler.is_stream = stream
|
||||
|
||||
self.router.add(
|
||||
uri=uri,
|
||||
methods=methods,
|
||||
handler=handler,
|
||||
host=host,
|
||||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
)
|
||||
return handler
|
||||
|
||||
return response
|
||||
|
||||
# Shorthand method decorators
|
||||
@@ -333,7 +335,7 @@ class Sanic:
|
||||
name=None,
|
||||
):
|
||||
"""
|
||||
Add an API URL under the **DELETE** *HTTP* method
|
||||
Add an API URL under the **PATCH** *HTTP* method
|
||||
|
||||
:param uri: URL to be tagged to **PATCH** method of *HTTP*
|
||||
:param host: Host IP or FQDN for the service to use
|
||||
@@ -599,9 +601,11 @@ class Sanic:
|
||||
:return: decorated method
|
||||
"""
|
||||
if attach_to == "request":
|
||||
self.request_middleware.append(middleware)
|
||||
if middleware not in self.request_middleware:
|
||||
self.request_middleware.append(middleware)
|
||||
if attach_to == "response":
|
||||
self.response_middleware.appendleft(middleware)
|
||||
if middleware not in self.response_middleware:
|
||||
self.response_middleware.appendleft(middleware)
|
||||
return middleware
|
||||
|
||||
# Decorator
|
||||
@@ -683,7 +687,7 @@ class Sanic:
|
||||
:param options: option dictionary with blueprint defaults
|
||||
:return: Nothing
|
||||
"""
|
||||
if isinstance(blueprint, (list, tuple)):
|
||||
if isinstance(blueprint, (list, tuple, BlueprintGroup)):
|
||||
for item in blueprint:
|
||||
self.blueprint(item, **options)
|
||||
return
|
||||
@@ -879,8 +883,6 @@ class Sanic:
|
||||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
|
||||
request.app = self
|
||||
response = await self._run_request_middleware(request)
|
||||
# No middleware results
|
||||
if not response:
|
||||
@@ -1122,6 +1124,8 @@ class Sanic:
|
||||
backlog: int = 100,
|
||||
stop_event: Any = None,
|
||||
access_log: Optional[bool] = None,
|
||||
return_asyncio_server=False,
|
||||
asyncio_server_kwargs=None,
|
||||
) -> None:
|
||||
"""
|
||||
Asynchronous version of :func:`run`.
|
||||
@@ -1155,6 +1159,13 @@ class Sanic:
|
||||
:type stop_event: None
|
||||
:param access_log: Enables writing access logs (slows server)
|
||||
:type access_log: bool
|
||||
:param return_asyncio_server: flag that defines whether there's a need
|
||||
to return asyncio.Server or
|
||||
start it serving right away
|
||||
:type return_asyncio_server: bool
|
||||
:param asyncio_server_kwargs: key-value arguments for
|
||||
asyncio/uvloop create_server method
|
||||
:type asyncio_server_kwargs: dict
|
||||
:return: Nothing
|
||||
"""
|
||||
|
||||
@@ -1185,7 +1196,7 @@ class Sanic:
|
||||
loop=get_event_loop(),
|
||||
protocol=protocol,
|
||||
backlog=backlog,
|
||||
run_async=True,
|
||||
run_async=return_asyncio_server,
|
||||
)
|
||||
|
||||
# Trigger before_start events
|
||||
@@ -1194,7 +1205,9 @@ class Sanic:
|
||||
server_settings.get("loop"),
|
||||
)
|
||||
|
||||
return await serve(**server_settings)
|
||||
return await serve(
|
||||
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
|
||||
)
|
||||
|
||||
async def trigger_events(self, events, loop):
|
||||
"""Trigger events (functions or async)
|
||||
@@ -1274,6 +1287,7 @@ class Sanic:
|
||||
"port": port,
|
||||
"sock": sock,
|
||||
"ssl": ssl,
|
||||
"app": self,
|
||||
"signal": Signal(),
|
||||
"debug": debug,
|
||||
"request_handler": self.handle_request,
|
||||
|
||||
120
sanic/blueprint_group.py
Normal file
120
sanic/blueprint_group.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from collections import MutableSequence
|
||||
|
||||
|
||||
class BlueprintGroup(MutableSequence):
|
||||
"""
|
||||
This class provides a mechanism to implement a Blueprint Group
|
||||
using the `Blueprint.group` method. To avoid having to re-write
|
||||
some of the existing implementation, this class provides a custom
|
||||
iterator implementation that will let you use the object of this
|
||||
class as a list/tuple inside the existing implementation.
|
||||
"""
|
||||
|
||||
__slots__ = ("_blueprints", "_url_prefix")
|
||||
|
||||
def __init__(self, url_prefix=None):
|
||||
"""
|
||||
Create a new Blueprint Group
|
||||
|
||||
:param url_prefix: URL: to be prefixed before all the Blueprint Prefix
|
||||
"""
|
||||
self._blueprints = []
|
||||
self._url_prefix = url_prefix
|
||||
|
||||
@property
|
||||
def url_prefix(self):
|
||||
"""
|
||||
Retrieve the URL prefix being used for the Current Blueprint Group
|
||||
:return: string with url prefix
|
||||
"""
|
||||
return self._url_prefix
|
||||
|
||||
@property
|
||||
def blueprints(self):
|
||||
"""
|
||||
Retrieve a list of all the available blueprints under this group.
|
||||
:return: List of Blueprint instance
|
||||
"""
|
||||
return self._blueprints
|
||||
|
||||
def __iter__(self):
|
||||
"""Tun the class Blueprint Group into an Iterable item"""
|
||||
return iter(self._blueprints)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""
|
||||
This method returns a blueprint inside the group specified by
|
||||
an index value. This will enable indexing, splice and slicing
|
||||
of the blueprint group like we can do with regular list/tuple.
|
||||
|
||||
This method is provided to ensure backward compatibility with
|
||||
any of the pre-existing usage that might break.
|
||||
|
||||
:param item: Index of the Blueprint item in the group
|
||||
:return: Blueprint object
|
||||
"""
|
||||
return self._blueprints[item]
|
||||
|
||||
def __setitem__(self, index: int, item: object) -> None:
|
||||
"""
|
||||
Abstract method implemented to turn the `BlueprintGroup` class
|
||||
into a list like object to support all the existing behavior.
|
||||
|
||||
This method is used to perform the list's indexed setter operation.
|
||||
|
||||
:param index: Index to use for inserting a new Blueprint item
|
||||
:param item: New `Blueprint` object.
|
||||
:return: None
|
||||
"""
|
||||
self._blueprints[index] = item
|
||||
|
||||
def __delitem__(self, index: int) -> None:
|
||||
"""
|
||||
Abstract method implemented to turn the `BlueprintGroup` class
|
||||
into a list like object to support all the existing behavior.
|
||||
|
||||
This method is used to delete an item from the list of blueprint
|
||||
groups like it can be done on a regular list with index.
|
||||
|
||||
:param index: Index to use for removing a new Blueprint item
|
||||
:return: None
|
||||
"""
|
||||
del self._blueprints[index]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get the Length of the blueprint group object.
|
||||
:return: Length of Blueprint group object
|
||||
"""
|
||||
return len(self._blueprints)
|
||||
|
||||
def insert(self, index: int, item: object) -> None:
|
||||
"""
|
||||
The Abstract class `MutableSequence` leverages this insert method to
|
||||
perform the `BlueprintGroup.append` operation.
|
||||
|
||||
:param index: Index to use for removing a new Blueprint item
|
||||
:param item: New `Blueprint` object.
|
||||
:return: None
|
||||
"""
|
||||
self._blueprints.insert(index, item)
|
||||
|
||||
def middleware(self, *args, **kwargs):
|
||||
"""
|
||||
A decorator that can be used to implement a Middleware plugin to
|
||||
all of the Blueprints that belongs to this specific Blueprint Group.
|
||||
|
||||
In case of nested Blueprint Groups, the same middleware is applied
|
||||
across each of the Blueprints recursively.
|
||||
|
||||
:param args: Optional positional Parameters to be use middleware
|
||||
:param kwargs: Optional Keyword arg to use with Middleware
|
||||
:return: Partial function to apply the middleware
|
||||
"""
|
||||
kwargs["bp_group"] = True
|
||||
|
||||
def register_middleware_for_blueprints(fn):
|
||||
for blueprint in self.blueprints:
|
||||
blueprint.middleware(fn, *args, **kwargs)
|
||||
|
||||
return register_middleware_for_blueprints
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections import defaultdict, namedtuple
|
||||
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.views import CompositionView
|
||||
|
||||
@@ -78,10 +79,12 @@ class Blueprint:
|
||||
for i in nested:
|
||||
if isinstance(i, (list, tuple)):
|
||||
yield from chain(i)
|
||||
elif isinstance(i, BlueprintGroup):
|
||||
yield from i.blueprints
|
||||
else:
|
||||
yield i
|
||||
|
||||
bps = []
|
||||
bps = BlueprintGroup(url_prefix=url_prefix)
|
||||
for bp in chain(blueprints):
|
||||
if bp.url_prefix is None:
|
||||
bp.url_prefix = ""
|
||||
@@ -212,6 +215,7 @@ class Blueprint:
|
||||
strict_slashes=None,
|
||||
version=None,
|
||||
name=None,
|
||||
stream=False,
|
||||
):
|
||||
"""Create a blueprint route from a function.
|
||||
|
||||
@@ -224,6 +228,7 @@ class Blueprint:
|
||||
training */*
|
||||
:param version: Blueprint Version
|
||||
:param name: user defined route name for url_for
|
||||
:param stream: boolean specifying if the handler is a stream handler
|
||||
:return: function or class instance
|
||||
"""
|
||||
# Handle HTTPMethodView differently
|
||||
@@ -246,6 +251,7 @@ class Blueprint:
|
||||
methods=methods,
|
||||
host=host,
|
||||
strict_slashes=strict_slashes,
|
||||
stream=stream,
|
||||
version=version,
|
||||
name=name,
|
||||
)(handler)
|
||||
@@ -324,7 +330,13 @@ class Blueprint:
|
||||
args = []
|
||||
return register_middleware(middleware)
|
||||
else:
|
||||
return register_middleware
|
||||
if kwargs.get("bp_group") and callable(args[0]):
|
||||
middleware = args[0]
|
||||
args = args[1:]
|
||||
kwargs.pop("bp_group")
|
||||
return register_middleware(middleware)
|
||||
else:
|
||||
return register_middleware
|
||||
|
||||
def exception(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import types
|
||||
|
||||
from distutils.util import strtobool
|
||||
|
||||
from sanic.exceptions import PyFileError
|
||||
|
||||
|
||||
@@ -27,6 +25,9 @@ DEFAULT_CONFIG = {
|
||||
"WEBSOCKET_WRITE_LIMIT": 2 ** 16,
|
||||
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
|
||||
"ACCESS_LOG": True,
|
||||
"PROXIES_COUNT": -1,
|
||||
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
|
||||
"REAL_IP_HEADER": "X-Real-IP",
|
||||
}
|
||||
|
||||
|
||||
@@ -127,6 +128,23 @@ class Config(dict):
|
||||
self[config_key] = float(v)
|
||||
except ValueError:
|
||||
try:
|
||||
self[config_key] = bool(strtobool(v))
|
||||
self[config_key] = strtobool(v)
|
||||
except ValueError:
|
||||
self[config_key] = v
|
||||
|
||||
|
||||
def strtobool(val):
|
||||
"""
|
||||
This function was borrowed from distutils.utils. While distutils
|
||||
is part of stdlib, it feels odd to use distutils in main application code.
|
||||
|
||||
The function was modified to walk its talk and actually return bool
|
||||
and not int.
|
||||
"""
|
||||
val = val.lower()
|
||||
if val in ("y", "yes", "t", "true", "on", "1"):
|
||||
return True
|
||||
elif val in ("n", "no", "f", "false", "off", "0"):
|
||||
return False
|
||||
else:
|
||||
raise ValueError("invalid truth value %r" % (val,))
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import re
|
||||
import string
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
DEFAULT_MAX_AGE = 0
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
# SimpleCookie
|
||||
@@ -103,6 +107,14 @@ class Cookie(dict):
|
||||
if key not in self._keys:
|
||||
raise KeyError("Unknown cookie property")
|
||||
if value is not False:
|
||||
if key.lower() == "max-age":
|
||||
if not str(value).isdigit():
|
||||
value = DEFAULT_MAX_AGE
|
||||
elif key.lower() == "expires":
|
||||
if not isinstance(value, datetime):
|
||||
raise TypeError(
|
||||
"Cookie 'expires' property must be a datetime"
|
||||
)
|
||||
return super().__setitem__(key, value)
|
||||
|
||||
def encode(self, encoding):
|
||||
@@ -126,16 +138,10 @@ class Cookie(dict):
|
||||
except TypeError:
|
||||
output.append("%s=%s" % (self._keys[key], value))
|
||||
elif key == "expires":
|
||||
try:
|
||||
output.append(
|
||||
"%s=%s"
|
||||
% (
|
||||
self._keys[key],
|
||||
value.strftime("%a, %d-%b-%Y %T GMT"),
|
||||
)
|
||||
)
|
||||
except AttributeError:
|
||||
output.append("%s=%s" % (self._keys[key], value))
|
||||
output.append(
|
||||
"%s=%s"
|
||||
% (self._keys[key], value.strftime("%a, %d-%b-%Y %T GMT"))
|
||||
)
|
||||
elif key in self._flags and self[key]:
|
||||
output.append(self._keys[key])
|
||||
else:
|
||||
|
||||
@@ -36,7 +36,15 @@ def _iter_module_files():
|
||||
def _get_args_for_reloading():
|
||||
"""Returns the executable."""
|
||||
rv = [sys.executable]
|
||||
rv.extend(sys.argv)
|
||||
main_module = sys.modules["__main__"]
|
||||
mod_spec = getattr(main_module, "__spec__", None)
|
||||
if mod_spec:
|
||||
# Parent exe was launched as a module rather than a script
|
||||
rv.extend(["-m", mod_spec.name])
|
||||
if len(sys.argv) > 1:
|
||||
rv.extend(sys.argv[1:])
|
||||
else:
|
||||
rv.extend(sys.argv)
|
||||
return rv
|
||||
|
||||
|
||||
@@ -44,6 +52,7 @@ def restart_with_reloader():
|
||||
"""Create a new process and a subprocess in it with the same arguments as
|
||||
this one.
|
||||
"""
|
||||
cwd = os.getcwd()
|
||||
args = _get_args_for_reloading()
|
||||
new_environ = os.environ.copy()
|
||||
new_environ["SANIC_SERVER_RUNNING"] = "true"
|
||||
@@ -51,7 +60,7 @@ def restart_with_reloader():
|
||||
worker_process = Process(
|
||||
target=subprocess.call,
|
||||
args=(cmd,),
|
||||
kwargs=dict(shell=True, env=new_environ),
|
||||
kwargs={"cwd": cwd, "shell": True, "env": new_environ},
|
||||
)
|
||||
worker_process.start()
|
||||
return worker_process
|
||||
|
||||
190
sanic/request.py
190
sanic/request.py
@@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
import email.utils
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from cgi import parse_header
|
||||
from collections import namedtuple
|
||||
from collections import defaultdict, namedtuple
|
||||
from http.cookies import SimpleCookie
|
||||
from urllib.parse import parse_qs, urlunparse
|
||||
from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
|
||||
|
||||
from httptools import parse_url
|
||||
|
||||
@@ -82,6 +84,7 @@ class Request(dict):
|
||||
"headers",
|
||||
"method",
|
||||
"parsed_args",
|
||||
"parsed_not_grouped_args",
|
||||
"parsed_files",
|
||||
"parsed_form",
|
||||
"parsed_json",
|
||||
@@ -92,11 +95,11 @@ class Request(dict):
|
||||
"version",
|
||||
)
|
||||
|
||||
def __init__(self, url_bytes, headers, version, method, transport):
|
||||
def __init__(self, url_bytes, headers, version, method, transport, app):
|
||||
self.raw_url = url_bytes
|
||||
# TODO: Content-Encoding detection
|
||||
self._parsed_url = parse_url(url_bytes)
|
||||
self.app = None
|
||||
self.app = app
|
||||
|
||||
self.headers = headers
|
||||
self.version = version
|
||||
@@ -108,15 +111,14 @@ class Request(dict):
|
||||
self.parsed_json = None
|
||||
self.parsed_form = None
|
||||
self.parsed_files = None
|
||||
self.parsed_args = None
|
||||
self.parsed_args = defaultdict(RequestParameters)
|
||||
self.parsed_not_grouped_args = defaultdict(list)
|
||||
self.uri_template = None
|
||||
self._cookies = None
|
||||
self.stream = None
|
||||
self.endpoint = None
|
||||
|
||||
def __repr__(self):
|
||||
if self.method is None or not self.path:
|
||||
return "<{0}>".format(self.__class__.__name__)
|
||||
return "<{0}: {1} {2}>".format(
|
||||
self.__class__.__name__, self.method, self.path
|
||||
)
|
||||
@@ -200,21 +202,117 @@ class Request(dict):
|
||||
|
||||
return self.parsed_files
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
if self.parsed_args is None:
|
||||
def get_args(
|
||||
self,
|
||||
keep_blank_values: bool = False,
|
||||
strict_parsing: bool = False,
|
||||
encoding: str = "utf-8",
|
||||
errors: str = "replace",
|
||||
) -> RequestParameters:
|
||||
"""
|
||||
Method to parse `query_string` using `urllib.parse.parse_qs`.
|
||||
This methods is used by `args` property.
|
||||
Can be used directly if you need to change default parameters.
|
||||
:param keep_blank_values: flag indicating whether blank values in
|
||||
percent-encoded queries should be treated as blank strings.
|
||||
A true value indicates that blanks should be retained as blank
|
||||
strings. The default false value indicates that blank values
|
||||
are to be ignored and treated as if they were not included.
|
||||
:type keep_blank_values: bool
|
||||
:param strict_parsing: flag indicating what to do with parsing errors.
|
||||
If false (the default), errors are silently ignored. If true,
|
||||
errors raise a ValueError exception.
|
||||
:type strict_parsing: bool
|
||||
:param encoding: specify how to decode percent-encoded sequences
|
||||
into Unicode characters, as accepted by the bytes.decode() method.
|
||||
:type encoding: str
|
||||
:param errors: specify how to decode percent-encoded sequences
|
||||
into Unicode characters, as accepted by the bytes.decode() method.
|
||||
:type errors: str
|
||||
:return: RequestParameters
|
||||
"""
|
||||
if not self.parsed_args[
|
||||
(keep_blank_values, strict_parsing, encoding, errors)
|
||||
]:
|
||||
if self.query_string:
|
||||
self.parsed_args = RequestParameters(
|
||||
parse_qs(self.query_string)
|
||||
self.parsed_args[
|
||||
(keep_blank_values, strict_parsing, encoding, errors)
|
||||
] = RequestParameters(
|
||||
parse_qs(
|
||||
qs=self.query_string,
|
||||
keep_blank_values=keep_blank_values,
|
||||
strict_parsing=strict_parsing,
|
||||
encoding=encoding,
|
||||
errors=errors,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.parsed_args = RequestParameters()
|
||||
return self.parsed_args
|
||||
|
||||
return self.parsed_args[
|
||||
(keep_blank_values, strict_parsing, encoding, errors)
|
||||
]
|
||||
|
||||
args = property(get_args)
|
||||
|
||||
@property
|
||||
def raw_args(self):
|
||||
def raw_args(self) -> dict:
|
||||
if self.app.debug: # pragma: no cover
|
||||
warnings.simplefilter("default")
|
||||
warnings.warn(
|
||||
"Use of raw_args will be deprecated in "
|
||||
"the future versions. Please use args or query_args "
|
||||
"properties instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return {k: v[0] for k, v in self.args.items()}
|
||||
|
||||
def get_query_args(
|
||||
self,
|
||||
keep_blank_values: bool = False,
|
||||
strict_parsing: bool = False,
|
||||
encoding: str = "utf-8",
|
||||
errors: str = "replace",
|
||||
) -> list:
|
||||
"""
|
||||
Method to parse `query_string` using `urllib.parse.parse_qsl`.
|
||||
This methods is used by `query_args` property.
|
||||
Can be used directly if you need to change default parameters.
|
||||
:param keep_blank_values: flag indicating whether blank values in
|
||||
percent-encoded queries should be treated as blank strings.
|
||||
A true value indicates that blanks should be retained as blank
|
||||
strings. The default false value indicates that blank values
|
||||
are to be ignored and treated as if they were not included.
|
||||
:type keep_blank_values: bool
|
||||
:param strict_parsing: flag indicating what to do with parsing errors.
|
||||
If false (the default), errors are silently ignored. If true,
|
||||
errors raise a ValueError exception.
|
||||
:type strict_parsing: bool
|
||||
:param encoding: specify how to decode percent-encoded sequences
|
||||
into Unicode characters, as accepted by the bytes.decode() method.
|
||||
:type encoding: str
|
||||
:param errors: specify how to decode percent-encoded sequences
|
||||
into Unicode characters, as accepted by the bytes.decode() method.
|
||||
:type errors: str
|
||||
:return: list
|
||||
"""
|
||||
if not self.parsed_not_grouped_args[
|
||||
(keep_blank_values, strict_parsing, encoding, errors)
|
||||
]:
|
||||
if self.query_string:
|
||||
self.parsed_not_grouped_args[
|
||||
(keep_blank_values, strict_parsing, encoding, errors)
|
||||
] = parse_qsl(
|
||||
qs=self.query_string,
|
||||
keep_blank_values=keep_blank_values,
|
||||
strict_parsing=strict_parsing,
|
||||
encoding=encoding,
|
||||
errors=errors,
|
||||
)
|
||||
return self.parsed_not_grouped_args[
|
||||
(keep_blank_values, strict_parsing, encoding, errors)
|
||||
]
|
||||
|
||||
query_args = property(get_query_args)
|
||||
|
||||
@property
|
||||
def cookies(self):
|
||||
if self._cookies is None:
|
||||
@@ -257,19 +355,38 @@ class Request(dict):
|
||||
|
||||
@property
|
||||
def remote_addr(self):
|
||||
"""Attempt to return the original client ip based on X-Forwarded-For.
|
||||
"""Attempt to return the original client ip based on X-Forwarded-For
|
||||
or X-Real-IP. If HTTP headers are unavailable or untrusted, returns
|
||||
an empty string.
|
||||
|
||||
:return: original client ip.
|
||||
"""
|
||||
if not hasattr(self, "_remote_addr"):
|
||||
forwarded_for = self.headers.get("X-Forwarded-For", "").split(",")
|
||||
remote_addrs = [
|
||||
addr
|
||||
for addr in [addr.strip() for addr in forwarded_for]
|
||||
if addr
|
||||
]
|
||||
if len(remote_addrs) > 0:
|
||||
self._remote_addr = remote_addrs[0]
|
||||
if self.app.config.PROXIES_COUNT == 0:
|
||||
self._remote_addr = ""
|
||||
elif self.app.config.REAL_IP_HEADER and self.headers.get(
|
||||
self.app.config.REAL_IP_HEADER
|
||||
):
|
||||
self._remote_addr = self.headers[
|
||||
self.app.config.REAL_IP_HEADER
|
||||
]
|
||||
elif self.app.config.FORWARDED_FOR_HEADER:
|
||||
forwarded_for = self.headers.get(
|
||||
self.app.config.FORWARDED_FOR_HEADER, ""
|
||||
).split(",")
|
||||
remote_addrs = [
|
||||
addr
|
||||
for addr in [addr.strip() for addr in forwarded_for]
|
||||
if addr
|
||||
]
|
||||
if self.app.config.PROXIES_COUNT == -1:
|
||||
self._remote_addr = remote_addrs[0]
|
||||
elif len(remote_addrs) >= self.app.config.PROXIES_COUNT:
|
||||
self._remote_addr = remote_addrs[
|
||||
-self.app.config.PROXIES_COUNT
|
||||
]
|
||||
else:
|
||||
self._remote_addr = ""
|
||||
else:
|
||||
self._remote_addr = ""
|
||||
return self._remote_addr
|
||||
@@ -358,15 +475,28 @@ def parse_multipart_form(body, boundary):
|
||||
)
|
||||
|
||||
if form_header_field == "content-disposition":
|
||||
file_name = form_parameters.get("filename")
|
||||
field_name = form_parameters.get("name")
|
||||
file_name = form_parameters.get("filename")
|
||||
|
||||
# non-ASCII filenames in RFC2231, "filename*" format
|
||||
if file_name is None and form_parameters.get("filename*"):
|
||||
encoding, _, value = email.utils.decode_rfc2231(
|
||||
form_parameters["filename*"]
|
||||
)
|
||||
file_name = unquote(value, encoding=encoding)
|
||||
elif form_header_field == "content-type":
|
||||
content_type = form_header_value
|
||||
content_charset = form_parameters.get("charset", "utf-8")
|
||||
|
||||
if field_name:
|
||||
post_data = form_part[line_index:-4]
|
||||
if file_name:
|
||||
if file_name is None:
|
||||
value = post_data.decode(content_charset)
|
||||
if field_name in fields:
|
||||
fields[field_name].append(value)
|
||||
else:
|
||||
fields[field_name] = [value]
|
||||
else:
|
||||
form_file = File(
|
||||
type=content_type, name=file_name, body=post_data
|
||||
)
|
||||
@@ -374,12 +504,6 @@ def parse_multipart_form(body, boundary):
|
||||
files[field_name].append(form_file)
|
||||
else:
|
||||
files[field_name] = [form_file]
|
||||
else:
|
||||
value = post_data.decode(content_charset)
|
||||
if field_name in fields:
|
||||
fields[field_name].append(value)
|
||||
else:
|
||||
fields[field_name] = [value]
|
||||
else:
|
||||
logger.debug(
|
||||
"Form-data field does not have a 'name' parameter "
|
||||
|
||||
@@ -59,16 +59,23 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
"status",
|
||||
"content_type",
|
||||
"headers",
|
||||
"chunked",
|
||||
"_cookies",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, streaming_fn, status=200, headers=None, content_type="text/plain"
|
||||
self,
|
||||
streaming_fn,
|
||||
status=200,
|
||||
headers=None,
|
||||
content_type="text/plain",
|
||||
chunked=True,
|
||||
):
|
||||
self.content_type = content_type
|
||||
self.streaming_fn = streaming_fn
|
||||
self.status = status
|
||||
self.headers = CIMultiDict(headers or {})
|
||||
self.chunked = chunked
|
||||
self._cookies = None
|
||||
|
||||
async def write(self, data):
|
||||
@@ -79,7 +86,10 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
if type(data) != bytes:
|
||||
data = self._encode_body(data)
|
||||
|
||||
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
|
||||
if self.chunked:
|
||||
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
|
||||
else:
|
||||
self.protocol.push_data(data)
|
||||
await self.protocol.drain()
|
||||
|
||||
async def stream(
|
||||
@@ -88,6 +98,8 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
"""Streams headers, runs the `streaming_fn` callback that writes
|
||||
content to the response body, then finalizes the response body.
|
||||
"""
|
||||
if version != "1.1":
|
||||
self.chunked = False
|
||||
headers = self.get_headers(
|
||||
version,
|
||||
keep_alive=keep_alive,
|
||||
@@ -96,7 +108,8 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
self.protocol.push_data(headers)
|
||||
await self.protocol.drain()
|
||||
await self.streaming_fn(self)
|
||||
self.protocol.push_data(b"0\r\n\r\n")
|
||||
if self.chunked:
|
||||
self.protocol.push_data(b"0\r\n\r\n")
|
||||
# no need to await drain here after this write, because it is the
|
||||
# very last thing we write and nothing needs to wait for it.
|
||||
|
||||
@@ -109,15 +122,16 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
if keep_alive and keep_alive_timeout is not None:
|
||||
timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
|
||||
|
||||
self.headers["Transfer-Encoding"] = "chunked"
|
||||
self.headers.pop("Content-Length", None)
|
||||
if self.chunked and version == "1.1":
|
||||
self.headers["Transfer-Encoding"] = "chunked"
|
||||
self.headers.pop("Content-Length", None)
|
||||
self.headers["Content-Type"] = self.headers.get(
|
||||
"Content-Type", self.content_type
|
||||
)
|
||||
|
||||
headers = self._parse_headers()
|
||||
|
||||
if self.status is 200:
|
||||
if self.status == 200:
|
||||
status = b"OK"
|
||||
else:
|
||||
status = STATUS_CODES.get(self.status)
|
||||
@@ -176,7 +190,7 @@ class HTTPResponse(BaseHTTPResponse):
|
||||
|
||||
headers = self._parse_headers()
|
||||
|
||||
if self.status is 200:
|
||||
if self.status == 200:
|
||||
status = b"OK"
|
||||
else:
|
||||
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
|
||||
@@ -327,6 +341,7 @@ async def file_stream(
|
||||
mime_type=None,
|
||||
headers=None,
|
||||
filename=None,
|
||||
chunked=True,
|
||||
_range=None,
|
||||
):
|
||||
"""Return a streaming response object with file data.
|
||||
@@ -336,6 +351,7 @@ async def file_stream(
|
||||
:param mime_type: Specific mime_type.
|
||||
:param headers: Custom Headers.
|
||||
:param filename: Override filename.
|
||||
:param chunked: Enable or disable chunked transfer-encoding
|
||||
:param _range:
|
||||
"""
|
||||
headers = headers or {}
|
||||
@@ -383,6 +399,7 @@ async def file_stream(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type=mime_type,
|
||||
chunked=chunked,
|
||||
)
|
||||
|
||||
|
||||
@@ -391,6 +408,7 @@ def stream(
|
||||
status=200,
|
||||
headers=None,
|
||||
content_type="text/plain; charset=utf-8",
|
||||
chunked=True,
|
||||
):
|
||||
"""Accepts an coroutine `streaming_fn` which can be used to
|
||||
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
|
||||
@@ -409,9 +427,14 @@ def stream(
|
||||
writes content to that response.
|
||||
:param mime_type: Specific mime_type.
|
||||
:param headers: Custom Headers.
|
||||
:param chunked: Enable or disable chunked transfer-encoding
|
||||
"""
|
||||
return StreamingHTTPResponse(
|
||||
streaming_fn, headers=headers, content_type=content_type, status=status
|
||||
streaming_fn,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
status=status,
|
||||
chunked=chunked,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ Parameter = namedtuple("Parameter", ["name", "cast"])
|
||||
|
||||
REGEX_TYPES = {
|
||||
"string": (str, r"[^/]+"),
|
||||
"int": (int, r"\d+"),
|
||||
"number": (float, r"[0-9\\.]+"),
|
||||
"int": (int, r"-?\d+"),
|
||||
"number": (float, r"-?(?:\d+(?:\.\d*)?|\.\d+)"),
|
||||
"alpha": (str, r"[A-Za-z]+"),
|
||||
"path": (str, r"[^/].*?"),
|
||||
"uuid": (
|
||||
|
||||
@@ -34,9 +34,6 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
current_time = None
|
||||
|
||||
|
||||
class Signal:
|
||||
stopped = False
|
||||
|
||||
@@ -47,6 +44,8 @@ class HttpProtocol(asyncio.Protocol):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
# app
|
||||
"app",
|
||||
# event loop, connection
|
||||
"loop",
|
||||
"transport",
|
||||
@@ -91,6 +90,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self,
|
||||
*,
|
||||
loop,
|
||||
app,
|
||||
request_handler,
|
||||
error_handler,
|
||||
signal=Signal(),
|
||||
@@ -110,6 +110,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
**kwargs
|
||||
):
|
||||
self.loop = loop
|
||||
self.app = app
|
||||
self.transport = None
|
||||
self.request = None
|
||||
self.parser = None
|
||||
@@ -118,7 +119,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self.router = router
|
||||
self.signal = signal
|
||||
self.access_log = access_log
|
||||
self.connections = connections or set()
|
||||
self.connections = connections if connections is not None else set()
|
||||
self.request_handler = request_handler
|
||||
self.error_handler = error_handler
|
||||
self.request_timeout = request_timeout
|
||||
@@ -171,7 +172,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self.request_timeout, self.request_timeout_callback
|
||||
)
|
||||
self.transport = transport
|
||||
self._last_request_time = current_time
|
||||
self._last_request_time = time()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.connections.discard(self)
|
||||
@@ -197,7 +198,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
# exactly what this timeout is checking for.
|
||||
# Check if elapsed time since request initiated exceeds our
|
||||
# configured maximum request timeout value
|
||||
time_elapsed = current_time - self._last_request_time
|
||||
time_elapsed = time() - self._last_request_time
|
||||
if time_elapsed < self.request_timeout:
|
||||
time_left = self.request_timeout - time_elapsed
|
||||
self._request_timeout_handler = self.loop.call_later(
|
||||
@@ -213,7 +214,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
def response_timeout_callback(self):
|
||||
# Check if elapsed time since response was initiated exceeds our
|
||||
# configured maximum request timeout value
|
||||
time_elapsed = current_time - self._last_request_time
|
||||
time_elapsed = time() - self._last_request_time
|
||||
if time_elapsed < self.response_timeout:
|
||||
time_left = self.response_timeout - time_elapsed
|
||||
self._response_timeout_handler = self.loop.call_later(
|
||||
@@ -234,7 +235,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
|
||||
:return: None
|
||||
"""
|
||||
time_elapsed = current_time - self._last_response_time
|
||||
time_elapsed = time() - self._last_response_time
|
||||
if time_elapsed < self.keep_alive_timeout:
|
||||
time_left = self.keep_alive_timeout - time_elapsed
|
||||
self._keep_alive_timeout_handler = self.loop.call_later(
|
||||
@@ -306,6 +307,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
version=self.parser.get_http_version(),
|
||||
method=self.parser.get_method().decode(),
|
||||
transport=self.transport,
|
||||
app=self.app,
|
||||
)
|
||||
# Remove any existing KeepAlive handler here,
|
||||
# It will be recreated if required on the new request.
|
||||
@@ -362,7 +364,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self._response_timeout_handler = self.loop.call_later(
|
||||
self.response_timeout, self.response_timeout_callback
|
||||
)
|
||||
self._last_request_time = current_time
|
||||
self._last_request_time = time()
|
||||
self._request_handler_task = self.loop.create_task(
|
||||
self.request_handler(
|
||||
self.request, self.write_response, self.stream_response
|
||||
@@ -449,7 +451,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self._keep_alive_timeout_handler = self.loop.call_later(
|
||||
self.keep_alive_timeout, self.keep_alive_timeout_callback
|
||||
)
|
||||
self._last_response_time = current_time
|
||||
self._last_response_time = time()
|
||||
self.cleanup()
|
||||
|
||||
async def drain(self):
|
||||
@@ -502,7 +504,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self._keep_alive_timeout_handler = self.loop.call_later(
|
||||
self.keep_alive_timeout, self.keep_alive_timeout_callback
|
||||
)
|
||||
self._last_response_time = current_time
|
||||
self._last_response_time = time()
|
||||
self.cleanup()
|
||||
|
||||
def write_error(self, exception):
|
||||
@@ -552,11 +554,15 @@ class HttpProtocol(asyncio.Protocol):
|
||||
|
||||
:return: None
|
||||
"""
|
||||
if from_error or self.transport.is_closing():
|
||||
if from_error or self.transport is None or self.transport.is_closing():
|
||||
logger.error(
|
||||
"Transport closed @ %s and exception "
|
||||
"experienced during error handling",
|
||||
self.transport.get_extra_info("peername"),
|
||||
(
|
||||
self.transport.get_extra_info("peername")
|
||||
if self.transport is not None
|
||||
else "N/A"
|
||||
),
|
||||
)
|
||||
logger.debug("Exception:", exc_info=True)
|
||||
else:
|
||||
@@ -595,18 +601,6 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self.transport = None
|
||||
|
||||
|
||||
def update_current_time(loop):
|
||||
"""Cache the current time, since it is needed at the end of every
|
||||
keep-alive request to update the request timeout time
|
||||
|
||||
:param loop:
|
||||
:return:
|
||||
"""
|
||||
global current_time
|
||||
current_time = time()
|
||||
loop.call_later(1, partial(update_current_time, loop))
|
||||
|
||||
|
||||
def trigger_events(events, loop):
|
||||
"""Trigger event callbacks (functions or async)
|
||||
|
||||
@@ -622,6 +616,7 @@ def trigger_events(events, loop):
|
||||
def serve(
|
||||
host,
|
||||
port,
|
||||
app,
|
||||
request_handler,
|
||||
error_handler,
|
||||
before_start=None,
|
||||
@@ -656,6 +651,7 @@ def serve(
|
||||
websocket_write_limit=2 ** 16,
|
||||
state=None,
|
||||
graceful_shutdown_timeout=15.0,
|
||||
asyncio_server_kwargs=None,
|
||||
):
|
||||
"""Start asynchronous HTTP Server on an individual process.
|
||||
|
||||
@@ -700,6 +696,8 @@ def serve(
|
||||
:param router: Router object
|
||||
:param graceful_shutdown_timeout: How long take to Force close non-idle
|
||||
connection
|
||||
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
|
||||
create_server method
|
||||
:return: Nothing
|
||||
"""
|
||||
if not run_async:
|
||||
@@ -716,6 +714,7 @@ def serve(
|
||||
loop=loop,
|
||||
connections=connections,
|
||||
signal=signal,
|
||||
app=app,
|
||||
request_handler=request_handler,
|
||||
error_handler=error_handler,
|
||||
request_timeout=request_timeout,
|
||||
@@ -734,7 +733,9 @@ def serve(
|
||||
state=state,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
asyncio_server_kwargs = (
|
||||
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
||||
)
|
||||
server_coroutine = loop.create_server(
|
||||
server,
|
||||
host,
|
||||
@@ -743,12 +744,9 @@ def serve(
|
||||
reuse_port=reuse_port,
|
||||
sock=sock,
|
||||
backlog=backlog,
|
||||
**asyncio_server_kwargs
|
||||
)
|
||||
|
||||
# Instead of pulling time at the end of every request,
|
||||
# pull it once per minute
|
||||
loop.call_soon(partial(update_current_time, loop))
|
||||
|
||||
if run_async:
|
||||
return server_coroutine
|
||||
|
||||
|
||||
170
sanic/testing.py
170
sanic/testing.py
@@ -1,4 +1,10 @@
|
||||
|
||||
from json import JSONDecodeError
|
||||
from socket import socket
|
||||
|
||||
import requests_async as requests
|
||||
import websockets
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import io
|
||||
@@ -18,54 +24,10 @@ HOST = "127.0.0.1"
|
||||
PORT = 42101
|
||||
|
||||
|
||||
# Annotations for `Session.request()`
|
||||
Cookies = typing.Union[
|
||||
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
|
||||
]
|
||||
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
|
||||
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
|
||||
TimeOut = typing.Union[float, typing.Tuple[float, float]]
|
||||
FileType = typing.MutableMapping[str, typing.IO]
|
||||
AuthType = typing.Union[
|
||||
typing.Tuple[str, str],
|
||||
requests.auth.AuthBase,
|
||||
typing.Callable[[requests.Request], requests.Request],
|
||||
]
|
||||
|
||||
|
||||
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||
def get_all(self, key: str, default: str) -> str:
|
||||
return self.getheaders(key)
|
||||
|
||||
|
||||
class _MockOriginalResponse:
|
||||
"""
|
||||
We have to jump through some hoops to present the response as if
|
||||
it was made using urllib3.
|
||||
"""
|
||||
|
||||
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
|
||||
self.msg = _HeaderDict(headers)
|
||||
self.closed = False
|
||||
|
||||
def isclosed(self) -> bool:
|
||||
return self.closed
|
||||
|
||||
|
||||
class _Upgrade(Exception):
|
||||
def __init__(self, session: "WebSocketTestSession") -> None:
|
||||
self.session = session
|
||||
|
||||
|
||||
def _get_reason_phrase(status_code: int) -> str:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||
def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None:
|
||||
class SanicTestClient:
|
||||
def __init__(self, app, port=PORT):
|
||||
"""Use port=None to bind to a random port"""
|
||||
self.app = app
|
||||
self.raise_server_exceptions = raise_server_exceptions
|
||||
|
||||
@@ -76,22 +38,55 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||
request.url
|
||||
)
|
||||
|
||||
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
|
||||
def get_new_session(self):
|
||||
return requests.Session()
|
||||
|
||||
if ":" in netloc:
|
||||
host, port_string = netloc.split(":", 1)
|
||||
port = int(port_string)
|
||||
else:
|
||||
host = netloc
|
||||
port = default_port
|
||||
async def _local_request(self, method, url, *args, **kwargs):
|
||||
logger.info(url)
|
||||
raw_cookies = kwargs.pop("raw_cookies", None)
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
elif port == default_port:
|
||||
headers = [(b"host", host.encode())]
|
||||
if method == "websocket":
|
||||
async with websockets.connect(url, *args, **kwargs) as websocket:
|
||||
websocket.opened = websocket.open
|
||||
return websocket
|
||||
else:
|
||||
headers = [(b"host", ("%s:%d" % (host, port)).encode())]
|
||||
async with self.get_new_session() as session:
|
||||
|
||||
try:
|
||||
response = await getattr(session, method.lower())(
|
||||
url, verify=False, *args, **kwargs
|
||||
)
|
||||
except NameError:
|
||||
raise Exception(response.status_code)
|
||||
|
||||
try:
|
||||
response.json = response.json()
|
||||
except (JSONDecodeError, UnicodeDecodeError):
|
||||
response.json = None
|
||||
|
||||
response.body = await response.read()
|
||||
response.status = response.status_code
|
||||
response.content_type = response.headers.get("content-type")
|
||||
|
||||
if raw_cookies:
|
||||
response.raw_cookies = {}
|
||||
for cookie in response.cookies:
|
||||
response.raw_cookies[cookie.name] = cookie
|
||||
|
||||
return response
|
||||
|
||||
def _sanic_endpoint_test(
|
||||
self,
|
||||
method="get",
|
||||
uri="/",
|
||||
gather_request=True,
|
||||
debug=False,
|
||||
server_kwargs={"auto_reload": False},
|
||||
*request_args,
|
||||
**request_kwargs
|
||||
):
|
||||
results = [None, None]
|
||||
exceptions = []
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
@@ -158,25 +153,31 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||
else:
|
||||
body_bytes = body
|
||||
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body_bytes}
|
||||
if self.port:
|
||||
server_kwargs = dict(host=HOST, port=self.port, **server_kwargs)
|
||||
host, port = HOST, self.port
|
||||
else:
|
||||
sock = socket()
|
||||
sock.bind((HOST, 0))
|
||||
server_kwargs = dict(sock=sock, **server_kwargs)
|
||||
host, port = sock.getsockname()
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
nonlocal raw_kwargs, response_started, response_complete, template, context
|
||||
if uri.startswith(
|
||||
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
|
||||
):
|
||||
url = uri
|
||||
else:
|
||||
uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri)
|
||||
scheme = "ws" if method == "websocket" else "http"
|
||||
url = "{scheme}://{host}:{port}{uri}".format(
|
||||
scheme=scheme, host=host, port=port, uri=uri
|
||||
)
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert (
|
||||
not response_started
|
||||
), 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["version"] = 11
|
||||
raw_kwargs["status"] = message["status"]
|
||||
raw_kwargs["reason"] = _get_reason_phrase(message["status"])
|
||||
raw_kwargs["headers"] = [
|
||||
(key.decode(), value.decode()) for key, value in message["headers"]
|
||||
]
|
||||
raw_kwargs["preload_content"] = False
|
||||
raw_kwargs["original_response"] = _MockOriginalResponse(
|
||||
raw_kwargs["headers"]
|
||||
@self.app.listener("after_server_start")
|
||||
async def _collect_response(sanic, loop):
|
||||
try:
|
||||
response = await self._local_request(
|
||||
method, url, *request_args, **request_kwargs
|
||||
)
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
@@ -204,11 +205,9 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||
template = None
|
||||
context = None
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
self.app.run(debug=debug, **server_kwargs)
|
||||
self.app.listeners["after_server_start"].pop()
|
||||
|
||||
self.app.is_running = True
|
||||
try:
|
||||
@@ -350,6 +349,7 @@ class SanicTestClient(requests.Session):
|
||||
return self.request("options", *args, **kwargs)
|
||||
|
||||
def head(self, *args, **kwargs):
|
||||
if 'uri' in kwargs:
|
||||
kwargs['url'] = kwargs.pop('uri')
|
||||
return self.request("head", *args, **kwargs)
|
||||
return self._sanic_endpoint_test("head", *args, **kwargs)
|
||||
|
||||
def websocket(self, *args, **kwargs):
|
||||
return self._sanic_endpoint_test("websocket", *args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user