Merge branch 'master' into asgi-refactor-attempt

This commit is contained in:
Adam Hopkins
2019-05-06 12:59:56 +03:00
committed by GitHub
85 changed files with 2466 additions and 808 deletions

View File

@@ -2,6 +2,6 @@ from sanic.app import Sanic
from sanic.blueprints import Blueprint
__version__ = "18.12.0"
__version__ = "19.03.1"
__all__ = ["Sanic", "Blueprint"]

View File

@@ -16,6 +16,7 @@ from typing import Any, Optional, Type, Union
from urllib.parse import urlencode, urlunparse
from sanic import reloader_helpers
from sanic.blueprint_group import BlueprintGroup
from sanic.config import BASE_LOGO, Config
from sanic.constants import HTTP_METHODS
from sanic.exceptions import SanicException, ServerError, URLBuildError
@@ -181,27 +182,28 @@ class Sanic:
strict_slashes = self.strict_slashes
def response(handler):
args = [key for key in signature(handler).parameters.keys()]
if args:
if stream:
handler.is_stream = stream
args = list(signature(handler).parameters.keys())
self.router.add(
uri=uri,
methods=methods,
handler=handler,
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
return handler
else:
if not args:
raise ValueError(
"Required parameter `request` missing "
"in the {0}() route?".format(handler.__name__)
)
if stream:
handler.is_stream = stream
self.router.add(
uri=uri,
methods=methods,
handler=handler,
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
return handler
return response
# Shorthand method decorators
@@ -333,7 +335,7 @@ class Sanic:
name=None,
):
"""
Add an API URL under the **DELETE** *HTTP* method
Add an API URL under the **PATCH** *HTTP* method
:param uri: URL to be tagged to **PATCH** method of *HTTP*
:param host: Host IP or FQDN for the service to use
@@ -599,9 +601,11 @@ class Sanic:
:return: decorated method
"""
if attach_to == "request":
self.request_middleware.append(middleware)
if middleware not in self.request_middleware:
self.request_middleware.append(middleware)
if attach_to == "response":
self.response_middleware.appendleft(middleware)
if middleware not in self.response_middleware:
self.response_middleware.appendleft(middleware)
return middleware
# Decorator
@@ -683,7 +687,7 @@ class Sanic:
:param options: option dictionary with blueprint defaults
:return: Nothing
"""
if isinstance(blueprint, (list, tuple)):
if isinstance(blueprint, (list, tuple, BlueprintGroup)):
for item in blueprint:
self.blueprint(item, **options)
return
@@ -879,8 +883,6 @@ class Sanic:
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
request.app = self
response = await self._run_request_middleware(request)
# No middleware results
if not response:
@@ -1122,6 +1124,8 @@ class Sanic:
backlog: int = 100,
stop_event: Any = None,
access_log: Optional[bool] = None,
return_asyncio_server=False,
asyncio_server_kwargs=None,
) -> None:
"""
Asynchronous version of :func:`run`.
@@ -1155,6 +1159,13 @@ class Sanic:
:type stop_event: None
:param access_log: Enables writing access logs (slows server)
:type access_log: bool
:param return_asyncio_server: flag that defines whether there's a need
to return asyncio.Server or
start it serving right away
:type return_asyncio_server: bool
:param asyncio_server_kwargs: key-value arguments for
asyncio/uvloop create_server method
:type asyncio_server_kwargs: dict
:return: Nothing
"""
@@ -1185,7 +1196,7 @@ class Sanic:
loop=get_event_loop(),
protocol=protocol,
backlog=backlog,
run_async=True,
run_async=return_asyncio_server,
)
# Trigger before_start events
@@ -1194,7 +1205,9 @@ class Sanic:
server_settings.get("loop"),
)
return await serve(**server_settings)
return await serve(
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
)
async def trigger_events(self, events, loop):
"""Trigger events (functions or async)
@@ -1274,6 +1287,7 @@ class Sanic:
"port": port,
"sock": sock,
"ssl": ssl,
"app": self,
"signal": Signal(),
"debug": debug,
"request_handler": self.handle_request,

120
sanic/blueprint_group.py Normal file
View File

@@ -0,0 +1,120 @@
from collections import MutableSequence
class BlueprintGroup(MutableSequence):
"""
This class provides a mechanism to implement a Blueprint Group
using the `Blueprint.group` method. To avoid having to re-write
some of the existing implementation, this class provides a custom
iterator implementation that will let you use the object of this
class as a list/tuple inside the existing implementation.
"""
__slots__ = ("_blueprints", "_url_prefix")
def __init__(self, url_prefix=None):
"""
Create a new Blueprint Group
:param url_prefix: URL: to be prefixed before all the Blueprint Prefix
"""
self._blueprints = []
self._url_prefix = url_prefix
@property
def url_prefix(self):
"""
Retrieve the URL prefix being used for the Current Blueprint Group
:return: string with url prefix
"""
return self._url_prefix
@property
def blueprints(self):
"""
Retrieve a list of all the available blueprints under this group.
:return: List of Blueprint instance
"""
return self._blueprints
def __iter__(self):
"""Tun the class Blueprint Group into an Iterable item"""
return iter(self._blueprints)
def __getitem__(self, item):
"""
This method returns a blueprint inside the group specified by
an index value. This will enable indexing, splice and slicing
of the blueprint group like we can do with regular list/tuple.
This method is provided to ensure backward compatibility with
any of the pre-existing usage that might break.
:param item: Index of the Blueprint item in the group
:return: Blueprint object
"""
return self._blueprints[item]
def __setitem__(self, index: int, item: object) -> None:
"""
Abstract method implemented to turn the `BlueprintGroup` class
into a list like object to support all the existing behavior.
This method is used to perform the list's indexed setter operation.
:param index: Index to use for inserting a new Blueprint item
:param item: New `Blueprint` object.
:return: None
"""
self._blueprints[index] = item
def __delitem__(self, index: int) -> None:
"""
Abstract method implemented to turn the `BlueprintGroup` class
into a list like object to support all the existing behavior.
This method is used to delete an item from the list of blueprint
groups like it can be done on a regular list with index.
:param index: Index to use for removing a new Blueprint item
:return: None
"""
del self._blueprints[index]
def __len__(self) -> int:
"""
Get the Length of the blueprint group object.
:return: Length of Blueprint group object
"""
return len(self._blueprints)
def insert(self, index: int, item: object) -> None:
"""
The Abstract class `MutableSequence` leverages this insert method to
perform the `BlueprintGroup.append` operation.
:param index: Index to use for removing a new Blueprint item
:param item: New `Blueprint` object.
:return: None
"""
self._blueprints.insert(index, item)
def middleware(self, *args, **kwargs):
"""
A decorator that can be used to implement a Middleware plugin to
all of the Blueprints that belongs to this specific Blueprint Group.
In case of nested Blueprint Groups, the same middleware is applied
across each of the Blueprints recursively.
:param args: Optional positional Parameters to be use middleware
:param kwargs: Optional Keyword arg to use with Middleware
:return: Partial function to apply the middleware
"""
kwargs["bp_group"] = True
def register_middleware_for_blueprints(fn):
for blueprint in self.blueprints:
blueprint.middleware(fn, *args, **kwargs)
return register_middleware_for_blueprints

View File

@@ -1,5 +1,6 @@
from collections import defaultdict, namedtuple
from sanic.blueprint_group import BlueprintGroup
from sanic.constants import HTTP_METHODS
from sanic.views import CompositionView
@@ -78,10 +79,12 @@ class Blueprint:
for i in nested:
if isinstance(i, (list, tuple)):
yield from chain(i)
elif isinstance(i, BlueprintGroup):
yield from i.blueprints
else:
yield i
bps = []
bps = BlueprintGroup(url_prefix=url_prefix)
for bp in chain(blueprints):
if bp.url_prefix is None:
bp.url_prefix = ""
@@ -212,6 +215,7 @@ class Blueprint:
strict_slashes=None,
version=None,
name=None,
stream=False,
):
"""Create a blueprint route from a function.
@@ -224,6 +228,7 @@ class Blueprint:
training */*
:param version: Blueprint Version
:param name: user defined route name for url_for
:param stream: boolean specifying if the handler is a stream handler
:return: function or class instance
"""
# Handle HTTPMethodView differently
@@ -246,6 +251,7 @@ class Blueprint:
methods=methods,
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)(handler)
@@ -324,7 +330,13 @@ class Blueprint:
args = []
return register_middleware(middleware)
else:
return register_middleware
if kwargs.get("bp_group") and callable(args[0]):
middleware = args[0]
args = args[1:]
kwargs.pop("bp_group")
return register_middleware(middleware)
else:
return register_middleware
def exception(self, *args, **kwargs):
"""

View File

@@ -1,8 +1,6 @@
import os
import types
from distutils.util import strtobool
from sanic.exceptions import PyFileError
@@ -27,6 +25,9 @@ DEFAULT_CONFIG = {
"WEBSOCKET_WRITE_LIMIT": 2 ** 16,
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
"ACCESS_LOG": True,
"PROXIES_COUNT": -1,
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
"REAL_IP_HEADER": "X-Real-IP",
}
@@ -127,6 +128,23 @@ class Config(dict):
self[config_key] = float(v)
except ValueError:
try:
self[config_key] = bool(strtobool(v))
self[config_key] = strtobool(v)
except ValueError:
self[config_key] = v
def strtobool(val):
"""
This function was borrowed from distutils.utils. While distutils
is part of stdlib, it feels odd to use distutils in main application code.
The function was modified to walk its talk and actually return bool
and not int.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError("invalid truth value %r" % (val,))

View File

@@ -1,6 +1,10 @@
import re
import string
from datetime import datetime
DEFAULT_MAX_AGE = 0
# ------------------------------------------------------------ #
# SimpleCookie
@@ -103,6 +107,14 @@ class Cookie(dict):
if key not in self._keys:
raise KeyError("Unknown cookie property")
if value is not False:
if key.lower() == "max-age":
if not str(value).isdigit():
value = DEFAULT_MAX_AGE
elif key.lower() == "expires":
if not isinstance(value, datetime):
raise TypeError(
"Cookie 'expires' property must be a datetime"
)
return super().__setitem__(key, value)
def encode(self, encoding):
@@ -126,16 +138,10 @@ class Cookie(dict):
except TypeError:
output.append("%s=%s" % (self._keys[key], value))
elif key == "expires":
try:
output.append(
"%s=%s"
% (
self._keys[key],
value.strftime("%a, %d-%b-%Y %T GMT"),
)
)
except AttributeError:
output.append("%s=%s" % (self._keys[key], value))
output.append(
"%s=%s"
% (self._keys[key], value.strftime("%a, %d-%b-%Y %T GMT"))
)
elif key in self._flags and self[key]:
output.append(self._keys[key])
else:

View File

@@ -36,7 +36,15 @@ def _iter_module_files():
def _get_args_for_reloading():
"""Returns the executable."""
rv = [sys.executable]
rv.extend(sys.argv)
main_module = sys.modules["__main__"]
mod_spec = getattr(main_module, "__spec__", None)
if mod_spec:
# Parent exe was launched as a module rather than a script
rv.extend(["-m", mod_spec.name])
if len(sys.argv) > 1:
rv.extend(sys.argv[1:])
else:
rv.extend(sys.argv)
return rv
@@ -44,6 +52,7 @@ def restart_with_reloader():
"""Create a new process and a subprocess in it with the same arguments as
this one.
"""
cwd = os.getcwd()
args = _get_args_for_reloading()
new_environ = os.environ.copy()
new_environ["SANIC_SERVER_RUNNING"] = "true"
@@ -51,7 +60,7 @@ def restart_with_reloader():
worker_process = Process(
target=subprocess.call,
args=(cmd,),
kwargs=dict(shell=True, env=new_environ),
kwargs={"cwd": cwd, "shell": True, "env": new_environ},
)
worker_process.start()
return worker_process

View File

@@ -1,11 +1,13 @@
import asyncio
import email.utils
import json
import sys
import warnings
from cgi import parse_header
from collections import namedtuple
from collections import defaultdict, namedtuple
from http.cookies import SimpleCookie
from urllib.parse import parse_qs, urlunparse
from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url
@@ -82,6 +84,7 @@ class Request(dict):
"headers",
"method",
"parsed_args",
"parsed_not_grouped_args",
"parsed_files",
"parsed_form",
"parsed_json",
@@ -92,11 +95,11 @@ class Request(dict):
"version",
)
def __init__(self, url_bytes, headers, version, method, transport):
def __init__(self, url_bytes, headers, version, method, transport, app):
self.raw_url = url_bytes
# TODO: Content-Encoding detection
self._parsed_url = parse_url(url_bytes)
self.app = None
self.app = app
self.headers = headers
self.version = version
@@ -108,15 +111,14 @@ class Request(dict):
self.parsed_json = None
self.parsed_form = None
self.parsed_files = None
self.parsed_args = None
self.parsed_args = defaultdict(RequestParameters)
self.parsed_not_grouped_args = defaultdict(list)
self.uri_template = None
self._cookies = None
self.stream = None
self.endpoint = None
def __repr__(self):
if self.method is None or not self.path:
return "<{0}>".format(self.__class__.__name__)
return "<{0}: {1} {2}>".format(
self.__class__.__name__, self.method, self.path
)
@@ -200,21 +202,117 @@ class Request(dict):
return self.parsed_files
@property
def args(self):
if self.parsed_args is None:
def get_args(
self,
keep_blank_values: bool = False,
strict_parsing: bool = False,
encoding: str = "utf-8",
errors: str = "replace",
) -> RequestParameters:
"""
Method to parse `query_string` using `urllib.parse.parse_qs`.
This methods is used by `args` property.
Can be used directly if you need to change default parameters.
:param keep_blank_values: flag indicating whether blank values in
percent-encoded queries should be treated as blank strings.
A true value indicates that blanks should be retained as blank
strings. The default false value indicates that blank values
are to be ignored and treated as if they were not included.
:type keep_blank_values: bool
:param strict_parsing: flag indicating what to do with parsing errors.
If false (the default), errors are silently ignored. If true,
errors raise a ValueError exception.
:type strict_parsing: bool
:param encoding: specify how to decode percent-encoded sequences
into Unicode characters, as accepted by the bytes.decode() method.
:type encoding: str
:param errors: specify how to decode percent-encoded sequences
into Unicode characters, as accepted by the bytes.decode() method.
:type errors: str
:return: RequestParameters
"""
if not self.parsed_args[
(keep_blank_values, strict_parsing, encoding, errors)
]:
if self.query_string:
self.parsed_args = RequestParameters(
parse_qs(self.query_string)
self.parsed_args[
(keep_blank_values, strict_parsing, encoding, errors)
] = RequestParameters(
parse_qs(
qs=self.query_string,
keep_blank_values=keep_blank_values,
strict_parsing=strict_parsing,
encoding=encoding,
errors=errors,
)
)
else:
self.parsed_args = RequestParameters()
return self.parsed_args
return self.parsed_args[
(keep_blank_values, strict_parsing, encoding, errors)
]
args = property(get_args)
@property
def raw_args(self):
def raw_args(self) -> dict:
if self.app.debug: # pragma: no cover
warnings.simplefilter("default")
warnings.warn(
"Use of raw_args will be deprecated in "
"the future versions. Please use args or query_args "
"properties instead",
DeprecationWarning,
)
return {k: v[0] for k, v in self.args.items()}
def get_query_args(
self,
keep_blank_values: bool = False,
strict_parsing: bool = False,
encoding: str = "utf-8",
errors: str = "replace",
) -> list:
"""
Method to parse `query_string` using `urllib.parse.parse_qsl`.
This methods is used by `query_args` property.
Can be used directly if you need to change default parameters.
:param keep_blank_values: flag indicating whether blank values in
percent-encoded queries should be treated as blank strings.
A true value indicates that blanks should be retained as blank
strings. The default false value indicates that blank values
are to be ignored and treated as if they were not included.
:type keep_blank_values: bool
:param strict_parsing: flag indicating what to do with parsing errors.
If false (the default), errors are silently ignored. If true,
errors raise a ValueError exception.
:type strict_parsing: bool
:param encoding: specify how to decode percent-encoded sequences
into Unicode characters, as accepted by the bytes.decode() method.
:type encoding: str
:param errors: specify how to decode percent-encoded sequences
into Unicode characters, as accepted by the bytes.decode() method.
:type errors: str
:return: list
"""
if not self.parsed_not_grouped_args[
(keep_blank_values, strict_parsing, encoding, errors)
]:
if self.query_string:
self.parsed_not_grouped_args[
(keep_blank_values, strict_parsing, encoding, errors)
] = parse_qsl(
qs=self.query_string,
keep_blank_values=keep_blank_values,
strict_parsing=strict_parsing,
encoding=encoding,
errors=errors,
)
return self.parsed_not_grouped_args[
(keep_blank_values, strict_parsing, encoding, errors)
]
query_args = property(get_query_args)
@property
def cookies(self):
if self._cookies is None:
@@ -257,19 +355,38 @@ class Request(dict):
@property
def remote_addr(self):
"""Attempt to return the original client ip based on X-Forwarded-For.
"""Attempt to return the original client ip based on X-Forwarded-For
or X-Real-IP. If HTTP headers are unavailable or untrusted, returns
an empty string.
:return: original client ip.
"""
if not hasattr(self, "_remote_addr"):
forwarded_for = self.headers.get("X-Forwarded-For", "").split(",")
remote_addrs = [
addr
for addr in [addr.strip() for addr in forwarded_for]
if addr
]
if len(remote_addrs) > 0:
self._remote_addr = remote_addrs[0]
if self.app.config.PROXIES_COUNT == 0:
self._remote_addr = ""
elif self.app.config.REAL_IP_HEADER and self.headers.get(
self.app.config.REAL_IP_HEADER
):
self._remote_addr = self.headers[
self.app.config.REAL_IP_HEADER
]
elif self.app.config.FORWARDED_FOR_HEADER:
forwarded_for = self.headers.get(
self.app.config.FORWARDED_FOR_HEADER, ""
).split(",")
remote_addrs = [
addr
for addr in [addr.strip() for addr in forwarded_for]
if addr
]
if self.app.config.PROXIES_COUNT == -1:
self._remote_addr = remote_addrs[0]
elif len(remote_addrs) >= self.app.config.PROXIES_COUNT:
self._remote_addr = remote_addrs[
-self.app.config.PROXIES_COUNT
]
else:
self._remote_addr = ""
else:
self._remote_addr = ""
return self._remote_addr
@@ -358,15 +475,28 @@ def parse_multipart_form(body, boundary):
)
if form_header_field == "content-disposition":
file_name = form_parameters.get("filename")
field_name = form_parameters.get("name")
file_name = form_parameters.get("filename")
# non-ASCII filenames in RFC2231, "filename*" format
if file_name is None and form_parameters.get("filename*"):
encoding, _, value = email.utils.decode_rfc2231(
form_parameters["filename*"]
)
file_name = unquote(value, encoding=encoding)
elif form_header_field == "content-type":
content_type = form_header_value
content_charset = form_parameters.get("charset", "utf-8")
if field_name:
post_data = form_part[line_index:-4]
if file_name:
if file_name is None:
value = post_data.decode(content_charset)
if field_name in fields:
fields[field_name].append(value)
else:
fields[field_name] = [value]
else:
form_file = File(
type=content_type, name=file_name, body=post_data
)
@@ -374,12 +504,6 @@ def parse_multipart_form(body, boundary):
files[field_name].append(form_file)
else:
files[field_name] = [form_file]
else:
value = post_data.decode(content_charset)
if field_name in fields:
fields[field_name].append(value)
else:
fields[field_name] = [value]
else:
logger.debug(
"Form-data field does not have a 'name' parameter "

View File

@@ -59,16 +59,23 @@ class StreamingHTTPResponse(BaseHTTPResponse):
"status",
"content_type",
"headers",
"chunked",
"_cookies",
)
def __init__(
self, streaming_fn, status=200, headers=None, content_type="text/plain"
self,
streaming_fn,
status=200,
headers=None,
content_type="text/plain",
chunked=True,
):
self.content_type = content_type
self.streaming_fn = streaming_fn
self.status = status
self.headers = CIMultiDict(headers or {})
self.chunked = chunked
self._cookies = None
async def write(self, data):
@@ -79,7 +86,10 @@ class StreamingHTTPResponse(BaseHTTPResponse):
if type(data) != bytes:
data = self._encode_body(data)
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
if self.chunked:
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
else:
self.protocol.push_data(data)
await self.protocol.drain()
async def stream(
@@ -88,6 +98,8 @@ class StreamingHTTPResponse(BaseHTTPResponse):
"""Streams headers, runs the `streaming_fn` callback that writes
content to the response body, then finalizes the response body.
"""
if version != "1.1":
self.chunked = False
headers = self.get_headers(
version,
keep_alive=keep_alive,
@@ -96,7 +108,8 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self)
self.protocol.push_data(b"0\r\n\r\n")
if self.chunked:
self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.
@@ -109,15 +122,16 @@ class StreamingHTTPResponse(BaseHTTPResponse):
if keep_alive and keep_alive_timeout is not None:
timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop("Content-Length", None)
if self.chunked and version == "1.1":
self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop("Content-Length", None)
self.headers["Content-Type"] = self.headers.get(
"Content-Type", self.content_type
)
headers = self._parse_headers()
if self.status is 200:
if self.status == 200:
status = b"OK"
else:
status = STATUS_CODES.get(self.status)
@@ -176,7 +190,7 @@ class HTTPResponse(BaseHTTPResponse):
headers = self._parse_headers()
if self.status is 200:
if self.status == 200:
status = b"OK"
else:
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
@@ -327,6 +341,7 @@ async def file_stream(
mime_type=None,
headers=None,
filename=None,
chunked=True,
_range=None,
):
"""Return a streaming response object with file data.
@@ -336,6 +351,7 @@ async def file_stream(
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
:param filename: Override filename.
:param chunked: Enable or disable chunked transfer-encoding
:param _range:
"""
headers = headers or {}
@@ -383,6 +399,7 @@ async def file_stream(
status=status,
headers=headers,
content_type=mime_type,
chunked=chunked,
)
@@ -391,6 +408,7 @@ def stream(
status=200,
headers=None,
content_type="text/plain; charset=utf-8",
chunked=True,
):
"""Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
@@ -409,9 +427,14 @@ def stream(
writes content to that response.
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
:param chunked: Enable or disable chunked transfer-encoding
"""
return StreamingHTTPResponse(
streaming_fn, headers=headers, content_type=content_type, status=status
streaming_fn,
headers=headers,
content_type=content_type,
status=status,
chunked=chunked,
)

View File

@@ -17,8 +17,8 @@ Parameter = namedtuple("Parameter", ["name", "cast"])
REGEX_TYPES = {
"string": (str, r"[^/]+"),
"int": (int, r"\d+"),
"number": (float, r"[0-9\\.]+"),
"int": (int, r"-?\d+"),
"number": (float, r"-?(?:\d+(?:\.\d*)?|\.\d+)"),
"alpha": (str, r"[A-Za-z]+"),
"path": (str, r"[^/].*?"),
"uuid": (

View File

@@ -34,9 +34,6 @@ except ImportError:
pass
current_time = None
class Signal:
stopped = False
@@ -47,6 +44,8 @@ class HttpProtocol(asyncio.Protocol):
"""
__slots__ = (
# app
"app",
# event loop, connection
"loop",
"transport",
@@ -91,6 +90,7 @@ class HttpProtocol(asyncio.Protocol):
self,
*,
loop,
app,
request_handler,
error_handler,
signal=Signal(),
@@ -110,6 +110,7 @@ class HttpProtocol(asyncio.Protocol):
**kwargs
):
self.loop = loop
self.app = app
self.transport = None
self.request = None
self.parser = None
@@ -118,7 +119,7 @@ class HttpProtocol(asyncio.Protocol):
self.router = router
self.signal = signal
self.access_log = access_log
self.connections = connections or set()
self.connections = connections if connections is not None else set()
self.request_handler = request_handler
self.error_handler = error_handler
self.request_timeout = request_timeout
@@ -171,7 +172,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_timeout, self.request_timeout_callback
)
self.transport = transport
self._last_request_time = current_time
self._last_request_time = time()
def connection_lost(self, exc):
self.connections.discard(self)
@@ -197,7 +198,7 @@ class HttpProtocol(asyncio.Protocol):
# exactly what this timeout is checking for.
# Check if elapsed time since request initiated exceeds our
# configured maximum request timeout value
time_elapsed = current_time - self._last_request_time
time_elapsed = time() - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = self.loop.call_later(
@@ -213,7 +214,7 @@ class HttpProtocol(asyncio.Protocol):
def response_timeout_callback(self):
# Check if elapsed time since response was initiated exceeds our
# configured maximum request timeout value
time_elapsed = current_time - self._last_request_time
time_elapsed = time() - self._last_request_time
if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = self.loop.call_later(
@@ -234,7 +235,7 @@ class HttpProtocol(asyncio.Protocol):
:return: None
"""
time_elapsed = current_time - self._last_response_time
time_elapsed = time() - self._last_response_time
if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = self.loop.call_later(
@@ -306,6 +307,7 @@ class HttpProtocol(asyncio.Protocol):
version=self.parser.get_http_version(),
method=self.parser.get_method().decode(),
transport=self.transport,
app=self.app,
)
# Remove any existing KeepAlive handler here,
# It will be recreated if required on the new request.
@@ -362,7 +364,7 @@ class HttpProtocol(asyncio.Protocol):
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = current_time
self._last_request_time = time()
self._request_handler_task = self.loop.create_task(
self.request_handler(
self.request, self.write_response, self.stream_response
@@ -449,7 +451,7 @@ class HttpProtocol(asyncio.Protocol):
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = current_time
self._last_response_time = time()
self.cleanup()
async def drain(self):
@@ -502,7 +504,7 @@ class HttpProtocol(asyncio.Protocol):
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = current_time
self._last_response_time = time()
self.cleanup()
def write_error(self, exception):
@@ -552,11 +554,15 @@ class HttpProtocol(asyncio.Protocol):
:return: None
"""
if from_error or self.transport.is_closing():
if from_error or self.transport is None or self.transport.is_closing():
logger.error(
"Transport closed @ %s and exception "
"experienced during error handling",
self.transport.get_extra_info("peername"),
(
self.transport.get_extra_info("peername")
if self.transport is not None
else "N/A"
),
)
logger.debug("Exception:", exc_info=True)
else:
@@ -595,18 +601,6 @@ class HttpProtocol(asyncio.Protocol):
self.transport = None
def update_current_time(loop):
"""Cache the current time, since it is needed at the end of every
keep-alive request to update the request timeout time
:param loop:
:return:
"""
global current_time
current_time = time()
loop.call_later(1, partial(update_current_time, loop))
def trigger_events(events, loop):
"""Trigger event callbacks (functions or async)
@@ -622,6 +616,7 @@ def trigger_events(events, loop):
def serve(
host,
port,
app,
request_handler,
error_handler,
before_start=None,
@@ -656,6 +651,7 @@ def serve(
websocket_write_limit=2 ** 16,
state=None,
graceful_shutdown_timeout=15.0,
asyncio_server_kwargs=None,
):
"""Start asynchronous HTTP Server on an individual process.
@@ -700,6 +696,8 @@ def serve(
:param router: Router object
:param graceful_shutdown_timeout: How long take to Force close non-idle
connection
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
create_server method
:return: Nothing
"""
if not run_async:
@@ -716,6 +714,7 @@ def serve(
loop=loop,
connections=connections,
signal=signal,
app=app,
request_handler=request_handler,
error_handler=error_handler,
request_timeout=request_timeout,
@@ -734,7 +733,9 @@ def serve(
state=state,
debug=debug,
)
asyncio_server_kwargs = (
asyncio_server_kwargs if asyncio_server_kwargs else {}
)
server_coroutine = loop.create_server(
server,
host,
@@ -743,12 +744,9 @@ def serve(
reuse_port=reuse_port,
sock=sock,
backlog=backlog,
**asyncio_server_kwargs
)
# Instead of pulling time at the end of every request,
# pull it once per minute
loop.call_soon(partial(update_current_time, loop))
if run_async:
return server_coroutine

View File

@@ -1,4 +1,10 @@
from json import JSONDecodeError
from socket import socket
import requests_async as requests
import websockets
import asyncio
import http
import io
@@ -18,54 +24,10 @@ HOST = "127.0.0.1"
PORT = 42101
# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
]
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
TimeOut = typing.Union[float, typing.Tuple[float, float]]
FileType = typing.MutableMapping[str, typing.IO]
AuthType = typing.Union[
typing.Tuple[str, str],
requests.auth.AuthBase,
typing.Callable[[requests.Request], requests.Request],
]
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key: str, default: str) -> str:
return self.getheaders(key)
class _MockOriginalResponse:
"""
We have to jump through some hoops to present the response as if
it was made using urllib3.
"""
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
self.msg = _HeaderDict(headers)
self.closed = False
def isclosed(self) -> bool:
return self.closed
class _Upgrade(Exception):
def __init__(self, session: "WebSocketTestSession") -> None:
self.session = session
def _get_reason_phrase(status_code: int) -> str:
try:
return http.HTTPStatus(status_code).phrase
except ValueError:
return ""
class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None:
class SanicTestClient:
def __init__(self, app, port=PORT):
"""Use port=None to bind to a random port"""
self.app = app
self.raise_server_exceptions = raise_server_exceptions
@@ -76,22 +38,55 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
request.url
)
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
def get_new_session(self):
return requests.Session()
if ":" in netloc:
host, port_string = netloc.split(":", 1)
port = int(port_string)
else:
host = netloc
port = default_port
async def _local_request(self, method, url, *args, **kwargs):
logger.info(url)
raw_cookies = kwargs.pop("raw_cookies", None)
# Include the 'host' header.
if "host" in request.headers:
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
elif port == default_port:
headers = [(b"host", host.encode())]
if method == "websocket":
async with websockets.connect(url, *args, **kwargs) as websocket:
websocket.opened = websocket.open
return websocket
else:
headers = [(b"host", ("%s:%d" % (host, port)).encode())]
async with self.get_new_session() as session:
try:
response = await getattr(session, method.lower())(
url, verify=False, *args, **kwargs
)
except NameError:
raise Exception(response.status_code)
try:
response.json = response.json()
except (JSONDecodeError, UnicodeDecodeError):
response.json = None
response.body = await response.read()
response.status = response.status_code
response.content_type = response.headers.get("content-type")
if raw_cookies:
response.raw_cookies = {}
for cookie in response.cookies:
response.raw_cookies[cookie.name] = cookie
return response
def _sanic_endpoint_test(
self,
method="get",
uri="/",
gather_request=True,
debug=False,
server_kwargs={"auto_reload": False},
*request_args,
**request_kwargs
):
results = [None, None]
exceptions = []
# Include other request headers.
headers += [
@@ -158,25 +153,31 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
else:
body_bytes = body
request_complete = True
return {"type": "http.request", "body": body_bytes}
if self.port:
server_kwargs = dict(host=HOST, port=self.port, **server_kwargs)
host, port = HOST, self.port
else:
sock = socket()
sock.bind((HOST, 0))
server_kwargs = dict(sock=sock, **server_kwargs)
host, port = sock.getsockname()
async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, response_complete, template, context
if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
url = uri
else:
uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri)
scheme = "ws" if method == "websocket" else "http"
url = "{scheme}://{host}:{port}{uri}".format(
scheme=scheme, host=host, port=port, uri=uri
)
if message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
raw_kwargs["version"] = 11
raw_kwargs["status"] = message["status"]
raw_kwargs["reason"] = _get_reason_phrase(message["status"])
raw_kwargs["headers"] = [
(key.decode(), value.decode()) for key, value in message["headers"]
]
raw_kwargs["preload_content"] = False
raw_kwargs["original_response"] = _MockOriginalResponse(
raw_kwargs["headers"]
@self.app.listener("after_server_start")
async def _collect_response(sanic, loop):
try:
response = await self._local_request(
method, url, *request_args, **request_kwargs
)
response_started = True
elif message["type"] == "http.response.body":
@@ -204,11 +205,9 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
template = None
context = None
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.app.run(debug=debug, **server_kwargs)
self.app.listeners["after_server_start"].pop()
self.app.is_running = True
try:
@@ -350,6 +349,7 @@ class SanicTestClient(requests.Session):
return self.request("options", *args, **kwargs)
def head(self, *args, **kwargs):
if 'uri' in kwargs:
kwargs['url'] = kwargs.pop('uri')
return self.request("head", *args, **kwargs)
return self._sanic_endpoint_test("head", *args, **kwargs)
def websocket(self, *args, **kwargs):
return self._sanic_endpoint_test("websocket", *args, **kwargs)