Streaming Server (#1876)

* Streaming request by async for.

* Make all requests streaming and preload body for non-streaming handlers.

* Cleanup of code and avoid mixing streaming responses.

* Async http protocol loop.

* Change of test: don't require early bad request error but only after CRLF-CRLF.

* Add back streaming requests.

* Rewritten request body parser.

* Misc. cleanup, down to 4 failing tests.

* All tests OK.

* Entirely remove request body queue.

* Let black f*ckup the layout

* Better testing error messages on protocol errors.

* Remove StreamBuffer tests because the type is about to be removed.

* Remove tests using the deprecated get_headers function that can no longer be supported. Chunked mode is now autodetected, so do not put content-length header if chunked mode is preferred.

* Major refactoring of HTTP protocol handling (new module http.py added), all requests made streaming. A few compatibility issues and a lot of cleanup to be done remain, 16 tests failing.

* Terminate check_timeouts once connection_task finishes.

* Code cleanup, 14 tests failing.

* Much cleanup, 12 failing...

* Even more cleanup and error checking, 8 failing tests.

* Remove keep-alive header from responses. First of all, it should say timeout=<value> which wasn't the case with existing implementation, and secondly none of the other web servers I tried include this header.

* Everything but CustomServer OK.

* Linter

* Disable custom protocol test

* Remove unnecessary variables, optimise performance.

* A test was missing that body_init/body_push/body_finish are never called. Rewritten using receive_body and case switching to make it fail if bypassed.

* Minor fixes.

* Remove unused code.

* Py 3.8 check for deprecated loop argument.

* Fix a middleware cancellation handling test with py38.

* Linter 'n fixes

* Typing

* Stricter handling of request header size

* More specific error messages on Payload Too Large.

* Init http.response = None

* Messages further tuned.

* Always try to consume request body, plus minor cleanup.

* Add a missing check in case of close_if_idle on a dead connection.

* Avoid error messages on PayloadTooLarge.

* Add test for new API.

* json takes str, not bytes

* Default to no maximum request size for streaming handlers.

* Fix chunked mode crash.

* Header values should be strictly ASCII but both UTF-8 and Latin-1 exist. Use UTF-8B to
cope with all.

* Refactoring and cleanup.

* Unify response header processing of ASGI and asyncio modes.

* Avoid special handling of StreamingHTTPResponse.

* 35 % speedup in HTTP/1.1 response formatting (not so much overall effect).

* Duplicate set-cookie headers were being produced.

* Cleanup processed_headers some more.

* Linting

* Import ordering

* Response middleware ran by async request.respond().

* Need to check if transport is closing to avoid getting stuck in sending loops after peer has disconnected.

* Middleware and error handling refactoring.

* Linter

* Fix tracking of HTTP stage when writing to transport fails.

* Add clarifying comment

* Add a check for request body functions and a test for NotImplementedError.

* Linter and typing

* These must be tuples + hack mypy warnings away.

* New streaming test and minor fixes.

* Constant receive buffer size.

* 256 KiB send and receive buffers.

* Revert "256 KiB send and receive buffers."

This reverts commit abc1e3edb2.

* app.handle_exception already sends the response.

* Improved handling of errors during request.

* An odd hack to avoid an httpx limitation that causes test failures.

* Limit request header size to 8 KiB at most.

* Remove unnecessary use of format string.

* Cleanup tests

* Remove artifact

* Fix type checking

* Mark test for skipping

* Cleanup some edge cases

* Add ignore_body flag to safe methods

* Add unit tests for timeout logic

* Add unit tests for timeout logic

* Fix Mock usage in timeout test

* Change logging test to only logger in handler

* Windows py3.8 logging issue with current testing client

* Add test_header_size_exceeded

* Resolve merge conflicts

* Add request middleware to hard exception handling

* Add request middleware to hard exception handling

* Request middleware on exception handlers

* Linting

* Cleanup deprecations

Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
L. Kärkkäinen
2021-01-11 00:45:36 +02:00
committed by GitHub
parent 574a9c27a6
commit 7028eae083
35 changed files with 1372 additions and 1348 deletions

View File

@@ -4,24 +4,27 @@ import os
import re
from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
from asyncio.futures import Future
from collections import defaultdict, deque
from functools import partial
from inspect import isawaitable, signature
from socket import socket
from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
from urllib.parse import urlencode, urlunparse
from sanic import reloader_helpers
from sanic.asgi import ASGIApp
from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
from sanic.config import BASE_LOGO, Config
from sanic.constants import HTTP_METHODS
from sanic.exceptions import SanicException, ServerError, URLBuildError
from sanic.handlers import ErrorHandler
from sanic.handlers import ErrorHandler, ListenerType, MiddlewareType
from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse
from sanic.router import Router
from sanic.server import (
AsyncioServer,
@@ -42,16 +45,16 @@ class Sanic:
def __init__(
self,
name=None,
router=None,
error_handler=None,
load_env=True,
request_class=None,
strict_slashes=False,
log_config=None,
configure_logging=True,
register=None,
):
name: str = None,
router: Router = None,
error_handler: ErrorHandler = None,
load_env: bool = True,
request_class: Request = None,
strict_slashes: bool = False,
log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True,
register: Optional[bool] = None,
) -> None:
# Get name from previous stack frame
if name is None:
@@ -59,7 +62,6 @@ class Sanic:
"Sanic instance cannot be unnamed. "
"Please use Sanic(name='your_application_name') instead.",
)
# logging
if configure_logging:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
@@ -70,22 +72,21 @@ class Sanic:
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler()
self.config = Config(load_env=load_env)
self.request_middleware = deque()
self.response_middleware = deque()
self.blueprints = {}
self._blueprint_order = []
self.request_middleware: Iterable[MiddlewareType] = deque()
self.response_middleware: Iterable[MiddlewareType] = deque()
self.blueprints: Dict[str, Blueprint] = {}
self._blueprint_order: List[Blueprint] = []
self.configure_logging = configure_logging
self.debug = None
self.sock = None
self.strict_slashes = strict_slashes
self.listeners = defaultdict(list)
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
self.is_stopping = False
self.is_running = False
self.is_request_stream = False
self.websocket_enabled = False
self.websocket_tasks = set()
self.named_request_middleware = {}
self.named_response_middleware = {}
self.websocket_tasks: Set[Future] = set()
self.named_request_middleware: Dict[str, MiddlewareType] = {}
self.named_response_middleware: Dict[str, MiddlewareType] = {}
# Register alternative method names
self.go_fast = self.run
@@ -162,6 +163,7 @@ class Sanic:
stream=False,
version=None,
name=None,
ignore_body=False,
):
"""Decorate a function to be registered as a route
@@ -180,9 +182,6 @@ class Sanic:
if not uri.startswith("/"):
uri = "/" + uri
if stream:
self.is_request_stream = True
if strict_slashes is None:
strict_slashes = self.strict_slashes
@@ -215,6 +214,7 @@ class Sanic:
strict_slashes=strict_slashes,
version=version,
name=name,
ignore_body=ignore_body,
)
)
return routes, handler
@@ -223,7 +223,13 @@ class Sanic:
# Shorthand method decorators
def get(
self, uri, host=None, strict_slashes=None, version=None, name=None
self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
):
"""
Add an API URL under the **GET** *HTTP* method
@@ -243,6 +249,7 @@ class Sanic:
strict_slashes=strict_slashes,
version=version,
name=name,
ignore_body=ignore_body,
)
def post(
@@ -306,7 +313,13 @@ class Sanic:
)
def head(
self, uri, host=None, strict_slashes=None, version=None, name=None
self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
):
return self.route(
uri,
@@ -315,10 +328,17 @@ class Sanic:
strict_slashes=strict_slashes,
version=version,
name=name,
ignore_body=ignore_body,
)
def options(
self, uri, host=None, strict_slashes=None, version=None, name=None
self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
):
"""
Add an API URL under the **OPTIONS** *HTTP* method
@@ -338,6 +358,7 @@ class Sanic:
strict_slashes=strict_slashes,
version=version,
name=name,
ignore_body=ignore_body,
)
def patch(
@@ -371,7 +392,13 @@ class Sanic:
)
def delete(
self, uri, host=None, strict_slashes=None, version=None, name=None
self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
):
"""
Add an API URL under the **DELETE** *HTTP* method
@@ -391,6 +418,7 @@ class Sanic:
strict_slashes=strict_slashes,
version=version,
name=name,
ignore_body=ignore_body,
)
def add_route(
@@ -497,6 +525,7 @@ class Sanic:
websocket_handler.__name__ = (
"websocket_handler_" + handler.__name__
)
websocket_handler.is_websocket = True
routes.extend(
self.router.add(
uri=uri,
@@ -861,7 +890,52 @@ class Sanic:
"""
pass
async def handle_request(self, request, write_callback, stream_callback):
async def handle_exception(self, request, exception):
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
# No middleware results
if not response:
try:
response = self.error_handler.response(request, exception)
if isawaitable(response):
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(request, e)
elif self.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
if response is not None:
try:
response = await request.respond(response)
except BaseException:
# Skip response middleware
request.stream.respond(response)
await response.send(end_stream=True)
raise
else:
response = request.stream.response
if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True)
else:
raise ServerError(
f"Invalid response type {response!r} (need HTTPResponse)"
)
async def handle_request(self, request):
"""Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so
exception handling must be done here
@@ -877,13 +951,27 @@ class Sanic:
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
cancelled = False
name = None
try:
# Fetch handler from router
handler, args, kwargs, uri, name, endpoint = self.router.get(
request
)
(
handler,
args,
kwargs,
uri,
name,
endpoint,
ignore_body,
) = self.router.get(request)
request.name = name
if request.stream.request_body and not ignore_body:
if self.router.is_stream_handler(request):
# Streaming handler: lift the size limit
request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await request.receive_body()
# -------------------------------------------- #
# Request Middleware
@@ -912,72 +1000,31 @@ class Sanic:
response = handler(request, *args, **kwargs)
if isawaitable(response):
response = await response
if response:
response = await request.respond(response)
else:
response = request.stream.response
# Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True)
else:
try:
# Fastest method for checking if the property exists
handler.is_websocket
except AttributeError:
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
except CancelledError:
# If response handler times out, the server handles the error
# and cancels the handle_request job.
# In this case, the transport is already closed and we cannot
# issue a response.
response = None
cancelled = True
raise
except Exception as e:
# -------------------------------------------- #
# Response Generation Failed
# -------------------------------------------- #
try:
response = self.error_handler.response(request, e)
if isawaitable(response):
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(
request=request, exception=e
)
elif self.debug:
response = HTTPResponse(
f"Error while "
f"handling error: {e}\nStack: {format_exc()}",
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
finally:
# -------------------------------------------- #
# Response Middleware
# -------------------------------------------- #
# Don't run response middleware if response is None
if response is not None:
try:
response = await self._run_response_middleware(
request, response, request_name=name
)
except CancelledError:
# Response middleware can timeout too, as above.
response = None
cancelled = True
except BaseException:
error_logger.exception(
"Exception occurred in one of response "
"middleware handlers"
)
if cancelled:
raise CancelledError()
# pass the response to the correct callback
if write_callback is None or isinstance(
response, StreamingHTTPResponse
):
if stream_callback:
await stream_callback(response)
else:
# Should only end here IF it is an ASGI websocket.
# TODO:
# - Add exception handling
pass
else:
write_callback(response)
await self.handle_exception(request, e)
# -------------------------------------------------------------------- #
# Testing
@@ -1213,7 +1260,12 @@ class Sanic:
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
if applicable_middleware:
# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
for middleware in applicable_middleware:
response = middleware(request)
if isawaitable(response):
@@ -1236,6 +1288,8 @@ class Sanic:
_response = await _response
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
return response

View File

@@ -2,27 +2,15 @@ import asyncio
import warnings
from inspect import isawaitable
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
MutableMapping,
Optional,
Tuple,
Union,
)
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
from urllib.parse import quote
import sanic.app # noqa
from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import ConnInfo, StreamBuffer
from sanic.server import ConnInfo
from sanic.websocket import WebSocketConnection
@@ -84,7 +72,7 @@ class MockTransport:
def get_extra_info(self, info: str) -> Union[str, bool, None]:
if info == "peername":
return self.scope.get("server")
return self.scope.get("client")
elif info == "sslcontext":
return self.scope.get("scheme") in ["https", "wss"]
return None
@@ -151,7 +139,7 @@ class Lifespan:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
if response and isawaitable(response):
await response
async def shutdown(self) -> None:
@@ -171,7 +159,7 @@ class Lifespan:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
if response and isawaitable(response):
await response
async def __call__(
@@ -192,7 +180,6 @@ class ASGIApp:
sanic_app: "sanic.app.Sanic"
request: Request
transport: MockTransport
do_stream: bool
lifespan: Lifespan
ws: Optional[WebSocketConnection]
@@ -215,9 +202,6 @@ class ASGIApp:
for key, value in scope.get("headers", [])
]
)
instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance)
if scope["type"] == "lifespan":
@@ -244,9 +228,7 @@ class ASGIApp:
)
await instance.ws.accept()
else:
pass
# TODO:
# - close connection
raise ServerError("Received unknown ASGI scope")
request_class = sanic_app.request_class or Request
instance.request = request_class(
@@ -257,161 +239,57 @@ class ASGIApp:
instance.transport,
sanic_app,
)
instance.request.stream = instance
instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport)
if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler(
instance.request
)
if is_stream_handler:
instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
)
instance.do_stream = True
return instance
async def read_body(self) -> bytes:
"""
Read and return the entire body from an incoming ASGI message.
"""
body = b""
more_body = True
while more_body:
message = await self.transport.receive()
body += message.get("body", b"")
more_body = message.get("more_body", False)
return body
async def stream_body(self) -> None:
async def read(self) -> Optional[bytes]:
"""
Read and stream the body in chunks from an incoming ASGI message.
"""
more_body = True
message = await self.transport.receive()
if not message.get("more_body", False):
self.request_body = False
return None
return message.get("body", b"")
while more_body:
message = await self.transport.receive()
chunk = message.get("body", b"")
await self.request.stream.put(chunk)
async def __aiter__(self):
while self.request_body:
data = await self.read()
if data:
yield data
more_body = message.get("more_body", False)
def respond(self, response):
response.stream, self.response = self, response
return response
await self.request.stream.put(None)
async def send(self, data, end_stream):
if self.response:
response, self.response = self.response, None
await self.transport.send(
{
"type": "http.response.start",
"status": response.status,
"headers": response.processed_headers,
}
)
response_body = getattr(response, "body", None)
if response_body:
data = response_body + data if data else response_body
await self.transport.send(
{
"type": "http.response.body",
"body": data.encode() if hasattr(data, "encode") else data,
"more_body": not end_stream,
}
)
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable
async def __call__(self) -> None:
"""
Handle the incoming request.
"""
if not self.do_stream:
self.request.body = await self.read_body()
else:
self.sanic_app.loop.create_task(self.stream_body())
handler = self.sanic_app.handle_request
callback = None if self.ws else self.stream_callback
await handler(self.request, None, callback)
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable
async def stream_callback(
self, response: Union[HTTPResponse, StreamingHTTPResponse]
) -> None:
"""
Write the response.
"""
headers: List[Tuple[bytes, bytes]] = []
cookies: Dict[str, str] = {}
content_length: List[str] = []
try:
content_length = response.headers.popall("content-length", [])
cookies = {
v.key: v
for _, v in list(
filter(
lambda item: item[0].lower() == "set-cookie",
response.headers.items(),
)
)
}
headers += [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name.lower() not in ["set-cookie"]
]
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.request.url,
type(response),
)
exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(
self.request, exception
)
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name not in (b"Set-Cookie",)
]
response.asgi = True
is_streaming = isinstance(response, StreamingHTTPResponse)
if is_streaming and getattr(response, "chunked", False):
# disable sanic chunking, this is done at the ASGI-server level
setattr(response, "chunked", False)
# content-length header is removed to signal to the ASGI-server
# to use automatic-chunking if it supports it
elif len(content_length) > 0:
headers += [
(b"content-length", str(content_length[0]).encode("latin-1"))
]
elif not is_streaming:
headers += [
(
b"content-length",
str(len(getattr(response, "body", b""))).encode("latin-1"),
)
]
if "content-type" not in response.headers:
headers += [
(b"content-type", str(response.content_type).encode("latin-1"))
]
if response.cookies:
cookies.update(
{
v.key: v
for _, v in response.cookies.items()
if v.key not in cookies.keys()
}
)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for k, cookie in cookies.items()
]
await self.transport.send(
{
"type": "http.response.start",
"status": response.status,
"headers": headers,
}
)
if isinstance(response, StreamingHTTPResponse):
response.protocol = self.transport.get_protocol()
await response.stream()
await response.protocol.complete()
else:
await self.transport.send(
{
"type": "http.response.body",
"body": response.body,
"more_body": False,
}
)
await self.sanic_app.handle_request(self.request)

View File

@@ -1,4 +1,5 @@
import asyncio
import os
import signal
from sys import argv
@@ -6,6 +7,9 @@ from sys import argv
from multidict import CIMultiDict # type: ignore
OS_IS_WINDOWS = os.name == "nt"
class Header(CIMultiDict):
def get_all(self, key):
return self.getall(key, default=[])
@@ -14,13 +18,13 @@ class Header(CIMultiDict):
use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio:
from trio import Path # type: ignore
from trio import open_file as open_async # type: ignore
import trio # type: ignore
def stat_async(path):
return Path(path).stat()
return trio.Path(path).stat()
open_async = trio.open_file
CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled])
else:
from aiofiles import open as aio_open # type: ignore
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
@@ -28,6 +32,8 @@ else:
async def open_async(file, mode="r", **kwargs):
return aio_open(file, mode, **kwargs)
CancelledErrors = tuple([asyncio.CancelledError])
def ctrlc_workaround_for_windows(app):
async def stay_active(app):

View File

@@ -23,6 +23,7 @@ BASE_LOGO = """
DEFAULT_CONFIG = {
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes
"REQUEST_BUFFER_QUEUE_SIZE": 100,
"REQUEST_BUFFER_SIZE": 65536, # 64 KiB
"REQUEST_TIMEOUT": 60, # 60 seconds
"RESPONSE_TIMEOUT": 60, # 60 seconds
"KEEP_ALIVE": True,

View File

@@ -206,7 +206,6 @@ class TextRenderer(BaseRenderer):
_, exc_value, __ = sys.exc_info()
exceptions = []
# traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
lines = [
f"{self.exception.__class__.__name__}: {self.exception} while "
f"handling path {self.request.path}",

View File

@@ -1,4 +1,6 @@
from asyncio.events import AbstractEventLoop
from traceback import format_exc
from typing import Any, Callable, Coroutine, Optional, TypeVar, Union
from sanic.errorpages import exception_response
from sanic.exceptions import (
@@ -7,7 +9,23 @@ from sanic.exceptions import (
InvalidRangeType,
)
from sanic.log import logger
from sanic.response import text
from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, text
Sanic = TypeVar("Sanic")
MiddlewareResponse = Union[
Optional[HTTPResponse], Coroutine[Any, Any, Optional[HTTPResponse]]
]
RequestMiddlewareType = Callable[[Request], MiddlewareResponse]
ResponseMiddlewareType = Callable[
[Request, BaseHTTPResponse], MiddlewareResponse
]
MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType]
ListenerType = Callable[
[Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]]
]
class ErrorHandler:

View File

@@ -7,6 +7,7 @@ from sanic.helpers import STATUS_CODES
HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str
HeaderBytesIterable = Iterable[Tuple[bytes, bytes]]
Options = Dict[str, Union[int, str]] # key=value fields in various headers
OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys
@@ -175,26 +176,18 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]:
return host.lower(), int(port) if port is not None else None
def format_http1(headers: HeaderIterable) -> bytes:
"""Convert a headers iterable into HTTP/1 header format.
- Outputs UTF-8 bytes where each header line ends with \\r\\n.
- Values are converted into strings if necessary.
"""
return "".join(f"{name}: {val}\r\n" for name, val in headers).encode()
_HTTP1_STATUSLINES = [
b"HTTP/1.1 %d %b\r\n" % (status, STATUS_CODES.get(status, b"UNKNOWN"))
for status in range(1000)
]
def format_http1_response(
status: int, headers: HeaderIterable, body=b""
) -> bytes:
"""Format a full HTTP/1.1 response.
- If `body` is included, content-length must be specified in headers.
"""
headerbytes = format_http1(headers)
return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % (
status,
STATUS_CODES.get(status, b"UNKNOWN"),
headerbytes,
body,
)
def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
"""Format a HTTP/1.1 response header."""
# Note: benchmarks show that here bytes concat is faster than bytearray,
# b"".join() or %-formatting. %timeit any changes you make.
ret = _HTTP1_STATUSLINES[status]
for h in headers:
ret += b"%b: %b\r\n" % h
ret += b"\r\n"
return ret

475
sanic/http.py Normal file
View File

@@ -0,0 +1,475 @@
from asyncio import CancelledError, sleep
from enum import Enum
from sanic.compat import Header
from sanic.exceptions import (
HeaderExpectationFailed,
InvalidUsage,
PayloadTooLarge,
ServerError,
ServiceUnavailable,
)
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
from sanic.log import access_logger, logger
class Stage(Enum):
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent, body in progress
FAILED = 100 # Unrecoverable state (error while sending response)
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http:
__slots__ = [
"_send",
"_receive_more",
"recv_buffer",
"protocol",
"expecting_continue",
"stage",
"keep_alive",
"head_only",
"request",
"exception",
"url",
"request_body",
"request_bytes",
"request_bytes_left",
"request_max_size",
"response",
"response_func",
"response_bytes_left",
"upgrade_websocket",
]
def __init__(self, protocol):
self._send = protocol.send
self._receive_more = protocol.receive_more
self.recv_buffer = protocol.recv_buffer
self.protocol = protocol
self.expecting_continue = False
self.stage = Stage.IDLE
self.request_body = None
self.request_bytes = None
self.request_bytes_left = None
self.request_max_size = protocol.request_max_size
self.keep_alive = True
self.head_only = None
self.request = None
self.response = None
self.exception = None
self.url = None
self.upgrade_websocket = False
def __bool__(self):
"""Test if request handling is in progress"""
return self.stage in (Stage.HANDLER, Stage.RESPONSE)
async def http1(self):
"""HTTP 1.1 connection handler"""
while True: # As long as connection stays keep-alive
try:
# Receive and handle a request
self.stage = Stage.REQUEST
self.response_func = self.http1_response_header
await self.http1_request_header()
self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
# Handler finished, response should've been sent
if self.stage is Stage.HANDLER and not self.upgrade_websocket:
raise ServerError("Handler produced no response")
if self.stage is Stage.RESPONSE:
await self.response.send(end_stream=True)
except CancelledError:
# Write an appropriate response before exiting
e = self.exception or ServiceUnavailable("Cancelled")
self.exception = None
self.keep_alive = False
await self.error_response(e)
except Exception as e:
# Write an error response
await self.error_response(e)
# Try to consume any remaining request body
if self.request_body:
if self.response and 200 <= self.response.status < 300:
logger.error(f"{self.request} body not consumed.")
try:
async for _ in self:
pass
except PayloadTooLarge:
# We won't read the body and that may cause httpx and
# tests to fail. This little delay allows clients to push
# a small request into network buffers before we close the
# socket, so that they are then able to read the response.
await sleep(0.001)
self.keep_alive = False
# Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive:
break
# Wait for next request
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self):
"""Receive and parse request header into self.request."""
HEADER_MAX_SIZE = min(8192, self.request_max_size)
# Receive until full header is in buffer
buf = self.recv_buffer
pos = 0
while True:
pos = buf.find(b"\r\n\r\n", pos)
if pos != -1:
break
pos = max(0, len(buf) - 3)
if pos >= HEADER_MAX_SIZE:
break
await self._receive_more()
if pos >= HEADER_MAX_SIZE:
raise PayloadTooLarge("Request header exceeds the size limit")
# Parse header content
try:
raw_headers = buf[:pos].decode(errors="surrogateescape")
reqline, *raw_headers = raw_headers.split("\r\n")
method, self.url, protocol = reqline.split(" ")
if protocol == "HTTP/1.1":
self.keep_alive = True
elif protocol == "HTTP/1.0":
self.keep_alive = False
else:
raise Exception # Raise a Bad Request on try-except
self.head_only = method.upper() == "HEAD"
request_body = False
headers = []
for name, value in (h.split(":", 1) for h in raw_headers):
name, value = h = name.lower(), value.lstrip()
if name in ("content-length", "transfer-encoding"):
request_body = True
elif name == "connection":
self.keep_alive = value.lower() == "keep-alive"
headers.append(h)
except Exception:
raise InvalidUsage("Bad Request")
headers_instance = Header(headers)
self.upgrade_websocket = headers_instance.get("upgrade") == "websocket"
# Prepare a Request object
request = self.protocol.request_class(
url_bytes=self.url.encode(),
headers=headers_instance,
version=protocol[5:],
method=method,
transport=self.protocol.transport,
app=self.protocol.app,
)
# Prepare for request body
self.request_bytes_left = self.request_bytes = 0
if request_body:
headers = request.headers
expect = headers.get("expect")
if expect is not None:
if expect.lower() == "100-continue":
self.expecting_continue = True
else:
raise HeaderExpectationFailed(f"Unknown Expect: {expect}")
if headers.get("transfer-encoding") == "chunked":
self.request_body = "chunked"
pos -= 2 # One CRLF stays in buffer
else:
self.request_body = True
self.request_bytes_left = self.request_bytes = int(
headers["content-length"]
)
# Remove header and its trailing CRLF
del buf[: pos + 4]
self.stage = Stage.HANDLER
self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1
async def http1_response_header(self, data, end_stream):
res = self.response
# Compatibility with simple response body
if not data and getattr(res, "body", None):
data, end_stream = res.body, True
size = len(data)
headers = res.headers
status = res.status
if not isinstance(status, int) or status < 200:
raise RuntimeError(f"Invalid response status {status!r}")
if not has_message_body(status):
# Header-only response status
self.response_func = None
if (
data
or not end_stream
or "content-length" in headers
or "transfer-encoding" in headers
):
data, size, end_stream = b"", 0, True
headers.pop("content-length", None)
headers.pop("transfer-encoding", None)
logger.warning(
f"Message body set in response on {self.request.path}. "
f"A {status} response may only have headers, no body."
)
elif self.head_only and "content-length" in headers:
self.response_func = None
elif end_stream:
# Non-streaming response (all in one block)
headers["content-length"] = size
self.response_func = None
elif "content-length" in headers:
# Streaming response with size known in advance
self.response_bytes_left = int(headers["content-length"]) - size
self.response_func = self.http1_response_normal
else:
# Length not known, use chunked encoding
headers["transfer-encoding"] = "chunked"
data = b"%x\r\n%b\r\n" % (size, data) if size else None
self.response_func = self.http1_response_chunked
if self.head_only:
# Head request: don't send body
data = b""
self.response_func = self.head_response_ignored
headers["connection"] = "keep-alive" if self.keep_alive else "close"
ret = format_http1_response(status, res.processed_headers)
if data:
ret += data
# Send a 100-continue if expected and not Expectation Failed
if self.expecting_continue:
self.expecting_continue = False
if status != 417:
ret = HTTP_CONTINUE + ret
# Send response
if self.protocol.access_log:
self.log_response()
await self._send(ret)
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
def head_response_ignored(self, data, end_stream):
"""HEAD response: body data silently ignored."""
if end_stream:
self.response_func = None
self.stage = Stage.IDLE
async def http1_response_chunked(self, data, end_stream):
"""Format a part of response body in chunked encoding."""
# Chunked encoding
size = len(data)
if end_stream:
await self._send(
b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
if size
else b"0\r\n\r\n"
)
self.response_func = None
self.stage = Stage.IDLE
elif size:
await self._send(b"%x\r\n%b\r\n" % (size, data))
async def http1_response_normal(self, data: bytes, end_stream: bool):
"""Format / keep track of non-chunked response."""
bytes_left = self.response_bytes_left - len(data)
if bytes_left <= 0:
if bytes_left < 0:
raise ServerError("Response was bigger than content-length")
await self._send(data)
self.response_func = None
self.stage = Stage.IDLE
else:
if end_stream:
raise ServerError("Response was smaller than content-length")
await self._send(data)
self.response_bytes_left = bytes_left
async def error_response(self, exception):
# Disconnect after an error if in any other state than handler
if self.stage is not Stage.HANDLER:
self.keep_alive = False
# Request failure? Respond but then disconnect
if self.stage is Stage.REQUEST:
self.stage = Stage.HANDLER
# From request and handler states we can respond, otherwise be silent
if self.stage is Stage.HANDLER:
app = self.protocol.app
if self.request is None:
self.create_empty_request()
await app.handle_exception(self.request, exception)
def create_empty_request(self):
"""Current error handling code needs a request object that won't exist
if an error occurred during before a request was received. Create a
bogus response for error handling use."""
# FIXME: Avoid this by refactoring error handling and response code
self.request = self.protocol.request_class(
url_bytes=self.url.encode() if self.url else b"*",
headers=Header({}),
version="1.1",
method="NONE",
transport=self.protocol.transport,
app=self.protocol.app,
)
self.request.stream = self
def log_response(self):
"""
Helper method provided to enable the logging of responses in case if
the :attr:`HttpProtocol.access_log` is enabled.
:param response: Response generated for the current request
:type response: :class:`sanic.response.HTTPResponse` or
:class:`sanic.response.StreamingHTTPResponse`
:return: None
"""
req, res = self.request, self.response
extra = {
"status": getattr(res, "status", 0),
"byte": getattr(self, "response_bytes_left", -1),
"host": "UNKNOWN",
"request": "nil",
}
if req is not None:
if req.ip:
extra["host"] = f"{req.ip}:{req.port}"
extra["request"] = f"{req.method} {req.url}"
access_logger.info("", extra=extra)
# Request methods
async def __aiter__(self):
"""Async iterate over request body."""
while self.request_body:
data = await self.read()
if data:
yield data
async def read(self):
"""Read some bytes of request body."""
# Send a 100-continue if needed
if self.expecting_continue:
self.expecting_continue = False
await self._send(HTTP_CONTINUE)
# Receive request body chunk
buf = self.recv_buffer
if self.request_bytes_left == 0 and self.request_body == "chunked":
# Process a chunk header: \r\n<size>[;<chunk extensions>]\r\n
while True:
pos = buf.find(b"\r\n", 3)
if pos != -1:
break
if len(buf) > 64:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
await self._receive_more()
try:
size = int(buf[2:pos].split(b";", 1)[0].decode(), 16)
except Exception:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
del buf[: pos + 2]
if size <= 0:
self.request_body = None
if size < 0:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
return None
self.request_bytes_left = size
self.request_bytes += size
# Request size limit
if self.request_bytes > self.request_max_size:
self.keep_alive = False
raise PayloadTooLarge("Request body exceeds the size limit")
# End of request body?
if not self.request_bytes_left:
self.request_body = None
return
# At this point we are good to read/return up to request_bytes_left
if not buf:
await self._receive_more()
data = bytes(buf[: self.request_bytes_left])
size = len(data)
del buf[:size]
self.request_bytes_left -= size
return data
# Response methods
def respond(self, response):
"""Initiate new streaming response.
Nothing is sent until the first send() call on the returned object, and
calling this function multiple times will just alter the response to be
given."""
if self.stage is not Stage.HANDLER:
self.stage = Stage.FAILED
raise RuntimeError("Response already started")
self.response, response.stream = response, self
return response
@property
def send(self):
return self.response_func

View File

@@ -1,4 +1,3 @@
import asyncio
import email.utils
from collections import defaultdict, namedtuple
@@ -8,6 +7,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url # type: ignore
from sanic.compat import CancelledErrors
from sanic.exceptions import InvalidUsage
from sanic.headers import (
parse_content_header,
@@ -16,6 +16,7 @@ from sanic.headers import (
parse_xforwarded,
)
from sanic.log import error_logger, logger
from sanic.response import BaseHTTPResponse, HTTPResponse
try:
@@ -24,7 +25,6 @@ except ImportError:
from json import loads as json_loads # type: ignore
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
EXPECT_HEADER = "EXPECT"
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
# > If the media type remains unknown, the recipient SHOULD treat it
@@ -45,35 +45,6 @@ class RequestParameters(dict):
return super().get(name, default)
class StreamBuffer:
def __init__(self, buffer_size=100):
self._queue = asyncio.Queue(buffer_size)
async def read(self):
""" Stop reading when gets None """
payload = await self._queue.get()
self._queue.task_done()
return payload
async def __aiter__(self):
"""Support `async for data in request.stream`"""
while True:
data = await self.read()
if not data:
break
yield data
async def put(self, payload):
await self._queue.put(payload)
def is_full(self):
return self._queue.full()
@property
def buffer_size(self):
return self._queue.maxsize
class Request:
"""Properties of an HTTP request such as URL, headers, etc."""
@@ -92,6 +63,7 @@ class Request:
"endpoint",
"headers",
"method",
"name",
"parsed_args",
"parsed_not_grouped_args",
"parsed_files",
@@ -99,6 +71,7 @@ class Request:
"parsed_json",
"parsed_forwarded",
"raw_url",
"request_middleware_started",
"stream",
"transport",
"uri_template",
@@ -117,9 +90,10 @@ class Request:
self.transport = transport
# Init but do not inhale
self.body_init()
self.body = b""
self.conn_info = None
self.ctx = SimpleNamespace()
self.name = None
self.parsed_forwarded = None
self.parsed_json = None
self.parsed_form = None
@@ -127,6 +101,7 @@ class Request:
self.parsed_args = defaultdict(RequestParameters)
self.parsed_not_grouped_args = defaultdict(list)
self.uri_template = None
self.request_middleware_started = False
self._cookies = None
self.stream = None
self.endpoint = None
@@ -135,36 +110,43 @@ class Request:
class_name = self.__class__.__name__
return f"<{class_name}: {self.method} {self.path}>"
def body_init(self):
""".. deprecated:: 20.3
To be removed in 21.3"""
self.body = []
def body_push(self, data):
""".. deprecated:: 20.3
To be removed in 21.3"""
self.body.append(data)
def body_finish(self):
""".. deprecated:: 20.3
To be removed in 21.3"""
self.body = b"".join(self.body)
async def respond(
self, response=None, *, status=200, headers=None, content_type=None
):
# This logic of determining which response to use is subject to change
if response is None:
response = self.stream.response or HTTPResponse(
status=status,
headers=headers,
content_type=content_type,
)
# Connect the response
if isinstance(response, BaseHTTPResponse):
response = self.stream.respond(response)
# Run response middleware
try:
response = await self.app._run_response_middleware(
self, response, request_name=self.name
)
except CancelledErrors:
raise
except Exception:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
)
return response
async def receive_body(self):
"""Receive request.body, if not already received.
Streaming handlers may call this to receive the full body.
Streaming handlers may call this to receive the full body. Sanic calls
this function before running any handlers of non-streaming routes.
This is added as a compatibility shim in Sanic 20.3 because future
versions of Sanic will make all requests streaming and will use this
function instead of the non-async body_init/push/finish functions.
Please make an issue if your code depends on the old functionality and
cannot be upgraded to the new API.
Custom request classes can override this for custom handling of both
streaming and non-streaming routes.
"""
if not self.stream:
return
self.body = b"".join([data async for data in self.stream])
if not self.body:
self.body = b"".join([data async for data in self.stream])
@property
def json(self):

View File

@@ -2,10 +2,10 @@ from functools import partial
from mimetypes import guess_type
from os import path
from urllib.parse import quote_plus
from warnings import warn
from sanic.compat import Header, open_async
from sanic.cookies import CookieJar
from sanic.headers import format_http1, format_http1_response
from sanic.helpers import has_message_body, remove_entity_headers
@@ -28,50 +28,51 @@ class BaseHTTPResponse:
return b""
return data.encode() if hasattr(data, "encode") else data
def _parse_headers(self):
return format_http1(self.headers.items())
@property
def cookies(self):
if self._cookies is None:
self._cookies = CookieJar(self.headers)
return self._cookies
def get_headers(
self,
version="1.1",
keep_alive=False,
keep_alive_timeout=None,
body=b"",
):
""".. deprecated:: 20.3:
This function is not public API and will be removed in 21.3."""
@property
def processed_headers(self):
"""Obtain a list of header tuples encoded in bytes for sending.
# self.headers get priority over content_type
if self.content_type and "Content-Type" not in self.headers:
self.headers["Content-Type"] = self.content_type
if keep_alive:
self.headers["Connection"] = "keep-alive"
if keep_alive_timeout is not None:
self.headers["Keep-Alive"] = keep_alive_timeout
else:
self.headers["Connection"] = "close"
if self.status in (304, 412):
Add and remove headers based on status and content_type.
"""
# TODO: Make a blacklist set of header names and then filter with that
if self.status in (304, 412): # Not Modified, Precondition Failed
self.headers = remove_entity_headers(self.headers)
if has_message_body(self.status):
self.headers.setdefault("content-type", self.content_type)
# Encode headers into bytes
return (
(name.encode("ascii"), f"{value}".encode(errors="surrogateescape"))
for name, value in self.headers.items()
)
return format_http1_response(self.status, self.headers.items(), body)
async def send(self, data=None, end_stream=None):
"""Send any pending response headers and the given data as body.
:param data: str or bytes to be written
:end_stream: whether to close the stream after this block
"""
if data is None and end_stream is None:
end_stream = True
if end_stream and not data and self.stream.send is None:
return
data = data.encode() if hasattr(data, "encode") else data or b""
await self.stream.send(data, end_stream=end_stream)
class StreamingHTTPResponse(BaseHTTPResponse):
"""Old style streaming response. Use `request.respond()` instead of this in
new code to avoid the callback."""
__slots__ = (
"protocol",
"streaming_fn",
"status",
"content_type",
"headers",
"chunked",
"_cookies",
)
@@ -81,63 +82,34 @@ class StreamingHTTPResponse(BaseHTTPResponse):
status=200,
headers=None,
content_type="text/plain; charset=utf-8",
chunked=True,
chunked="deprecated",
):
if chunked != "deprecated":
warn(
"The chunked argument has been deprecated and will be "
"removed in v21.6"
)
super().__init__()
self.content_type = content_type
self.streaming_fn = streaming_fn
self.status = status
self.headers = Header(headers or {})
self.chunked = chunked
self._cookies = None
self.protocol = None
async def write(self, data):
"""Writes a chunk of data to the streaming response.
:param data: str or bytes-ish data to be written.
"""
data = self._encode_body(data)
await super().send(self._encode_body(data))
# `chunked` will always be False in ASGI-mode, even if the underlying
# ASGI Transport implements Chunked transport. That does it itself.
if self.chunked:
await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
else:
await self.protocol.push_data(data)
await self.protocol.drain()
async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
"""Streams headers, runs the `streaming_fn` callback that writes
content to the response body, then finalizes the response body.
"""
if version != "1.1":
self.chunked = False
if not getattr(self, "asgi", False):
headers = self.get_headers(
version,
keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout,
)
await self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self)
if self.chunked:
await self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.
def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
if self.chunked and version == "1.1":
self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop("Content-Length", None)
return super().get_headers(version, keep_alive, keep_alive_timeout)
async def send(self, *args, **kwargs):
if self.streaming_fn is not None:
await self.streaming_fn(self)
self.streaming_fn = None
await super().send(*args, **kwargs)
class HTTPResponse(BaseHTTPResponse):
@@ -158,22 +130,6 @@ class HTTPResponse(BaseHTTPResponse):
self.headers = Header(headers or {})
self._cookies = None
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
body = b""
if has_message_body(self.status):
body = self.body
self.headers["Content-Length"] = self.headers.get(
"Content-Length", len(self.body)
)
return self.get_headers(version, keep_alive, keep_alive_timeout, body)
@property
def cookies(self):
if self._cookies is None:
self._cookies = CookieJar(self.headers)
return self._cookies
def empty(status=204, headers=None):
"""
@@ -319,7 +275,7 @@ async def file_stream(
mime_type=None,
headers=None,
filename=None,
chunked=True,
chunked="deprecated",
_range=None,
):
"""Return a streaming response object with file data.
@@ -329,9 +285,15 @@ async def file_stream(
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
:param filename: Override filename.
:param chunked: Enable or disable chunked transfer-encoding
:param chunked: Deprecated
:param _range:
"""
if chunked != "deprecated":
warn(
"The chunked argument has been deprecated and will be "
"removed in v21.6"
)
headers = headers or {}
if filename:
headers.setdefault(
@@ -370,7 +332,6 @@ async def file_stream(
status=status,
headers=headers,
content_type=mime_type,
chunked=chunked,
)
@@ -379,7 +340,7 @@ def stream(
status=200,
headers=None,
content_type="text/plain; charset=utf-8",
chunked=True,
chunked="deprecated",
):
"""Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
@@ -398,14 +359,19 @@ def stream(
writes content to that response.
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
:param chunked: Enable or disable chunked transfer-encoding
:param chunked: Deprecated
"""
if chunked != "deprecated":
warn(
"The chunked argument has been deprecated and will be "
"removed in v21.6"
)
return StreamingHTTPResponse(
streaming_fn,
headers=headers,
content_type=content_type,
status=status,
chunked=chunked,
)

View File

@@ -20,6 +20,7 @@ Route = namedtuple(
"name",
"uri",
"endpoint",
"ignore_body",
],
)
Parameter = namedtuple("Parameter", ["name", "cast"])
@@ -135,6 +136,7 @@ class Router:
handler,
host=None,
strict_slashes=False,
ignore_body=False,
version=None,
name=None,
):
@@ -146,6 +148,7 @@ class Router:
:param handler: request handler function.
When executed, it should provide a response object.
:param strict_slashes: strict to trailing slash
:param ignore_body: Handler should not read the body, if any
:param version: current version of the route or blueprint. See
docs for further details.
:return: Nothing
@@ -155,7 +158,9 @@ class Router:
version = re.escape(str(version).strip("/").lstrip("v"))
uri = "/".join([f"/v{version}", uri.lstrip("/")])
# add regular version
routes.append(self._add(uri, methods, handler, host, name))
routes.append(
self._add(uri, methods, handler, host, name, ignore_body)
)
if strict_slashes:
return routes
@@ -187,14 +192,20 @@ class Router:
)
# add version with trailing slash
if slash_is_missing:
routes.append(self._add(uri + "/", methods, handler, host, name))
routes.append(
self._add(uri + "/", methods, handler, host, name, ignore_body)
)
# add version without trailing slash
elif without_slash_is_missing:
routes.append(self._add(uri[:-1], methods, handler, host, name))
routes.append(
self._add(uri[:-1], methods, handler, host, name, ignore_body)
)
return routes
def _add(self, uri, methods, handler, host=None, name=None):
def _add(
self, uri, methods, handler, host=None, name=None, ignore_body=False
):
"""Add a handler to the route list
:param uri: path to match
@@ -326,6 +337,7 @@ class Router:
name=handler_name,
uri=uri,
endpoint=endpoint,
ignore_body=ignore_body,
)
self.routes_all[uri] = route
@@ -465,7 +477,15 @@ class Router:
if hasattr(route_handler, "handlers"):
route_handler = route_handler.handlers[method]
return route_handler, [], kwargs, route.uri, route.name, route.endpoint
return (
route_handler,
[],
kwargs,
route.uri,
route.name,
route.endpoint,
route.ignore_body,
)
def is_stream_handler(self, request):
"""Handler for request is stream or not.

View File

@@ -5,33 +5,22 @@ import secrets
import socket
import stat
import sys
import traceback
from collections import deque
from asyncio import CancelledError
from functools import partial
from inspect import isawaitable
from ipaddress import ip_address
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from time import time
from time import monotonic as current_time
from typing import Dict, Type, Union
from httptools import HttpRequestParser # type: ignore
from httptools.parser.errors import HttpParserError # type: ignore
from sanic.compat import Header, ctrlc_workaround_for_windows
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.config import Config
from sanic.exceptions import (
HeaderExpectationFailed,
InvalidUsage,
PayloadTooLarge,
RequestTimeout,
ServerError,
ServiceUnavailable,
)
from sanic.log import access_logger, logger
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
from sanic.response import HTTPResponse
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import logger
from sanic.request import Request
try:
@@ -42,8 +31,6 @@ try:
except ImportError:
pass
OS_IS_WINDOWS = os.name == "nt"
class Signal:
stopped = False
@@ -99,10 +86,7 @@ class HttpProtocol(asyncio.Protocol):
"signal",
"conn_info",
# request params
"parser",
"request",
"url",
"headers",
# request config
"request_handler",
"request_timeout",
@@ -111,26 +95,21 @@ class HttpProtocol(asyncio.Protocol):
"request_max_size",
"request_buffer_queue_size",
"request_class",
"is_request_stream",
"error_handler",
# enable or disable access log purpose
"access_log",
# connection management
"_total_request_size",
"_request_timeout_handler",
"_response_timeout_handler",
"_keep_alive_timeout_handler",
"_last_request_time",
"_last_response_time",
"_is_stream_handler",
"_not_paused",
"_request_handler_task",
"_request_stream_task",
"_keep_alive",
"_header_fragment",
"state",
"url",
"_handler_task",
"_can_write",
"_data_received",
"_time",
"_task",
"_http",
"_exception",
"recv_buffer",
"_unix",
"_body_chunks",
)
def __init__(
@@ -148,12 +127,10 @@ class HttpProtocol(asyncio.Protocol):
self.loop = loop
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
self.app = app
self.url = None
self.transport = None
self.conn_info = None
self.request = None
self.parser = None
self.url = None
self.headers = None
self.signal = signal
self.access_log = self.app.config.ACCESS_LOG
self.connections = connections if connections is not None else set()
@@ -167,511 +144,99 @@ class HttpProtocol(asyncio.Protocol):
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
self.is_request_stream = self.app.is_request_stream
self._is_stream_handler = False
self._not_paused = asyncio.Event(loop=deprecated_loop)
self._total_request_size = 0
self._request_timeout_handler = None
self._response_timeout_handler = None
self._keep_alive_timeout_handler = None
self._last_request_time = None
self._last_response_time = None
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = self.app.config.KEEP_ALIVE
self._header_fragment = b""
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._data_received = asyncio.Event(loop=deprecated_loop)
self._can_write = asyncio.Event(loop=deprecated_loop)
self._can_write.set()
self._exception = None
self._unix = unix
self._not_paused.set()
self._body_chunks = deque()
@property
def keep_alive(self):
def _setup_connection(self):
self._http = Http(self)
self._time = current_time()
self.check_timeouts()
async def connection_task(self):
"""Run a HTTP connection.
Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there.
"""
Check if the connection needs to be kept alive based on the params
attached to the `_keep_alive` attribute, :attr:`Signal.stopped`
and :func:`HttpProtocol.parser.should_keep_alive`
:return: ``True`` if connection is to be kept alive ``False`` else
"""
return (
self._keep_alive
and not self.signal.stopped
and self.parser.should_keep_alive()
)
# -------------------------------------------- #
# Connection
# -------------------------------------------- #
def connection_made(self, transport):
self.connections.add(self)
self._request_timeout_handler = self.loop.call_later(
self.request_timeout, self.request_timeout_callback
)
self.transport = transport
self.conn_info = ConnInfo(transport, unix=self._unix)
self._last_request_time = time()
def connection_lost(self, exc):
self.connections.discard(self)
if self._request_handler_task:
self._request_handler_task.cancel()
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
def pause_writing(self):
self._not_paused.clear()
def resume_writing(self):
self._not_paused.set()
def request_timeout_callback(self):
# See the docstring in the RequestTimeout exception, to see
# exactly what this timeout is checking for.
# Check if elapsed time since request initiated exceeds our
# configured maximum request timeout value
time_elapsed = time() - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = self.loop.call_later(
time_left, self.request_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self):
# Check if elapsed time since response was initiated exceeds our
# configured maximum request timeout value
time_elapsed = time() - self._last_request_time
if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = self.loop.call_later(
time_left, self.response_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self):
"""
Check if elapsed time since last response exceeds our configured
maximum keep alive timeout value and if so, close the transport
pipe and let the response writer handle the error.
:return: None
"""
time_elapsed = time() - self._last_response_time
if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = self.loop.call_later(
time_left, self.keep_alive_timeout_callback
)
else:
logger.debug("KeepAlive Timeout. Closing connection.")
self.transport.close()
self.transport = None
# -------------------------------------------- #
# Parsing
# -------------------------------------------- #
def data_received(self, data):
# Check for the request itself getting too large and exceeding
# memory limits
self._total_request_size += len(data)
if self._total_request_size > self.request_max_size:
self.write_error(PayloadTooLarge("Payload Too Large"))
# Create parser if this is the first time we're receiving data
if self.parser is None:
assert self.request is None
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state["requests_count"] = self.state["requests_count"] + 1
# Parse request chunk or close connection
try:
self.parser.feed_data(data)
except HttpParserError:
message = "Bad Request"
if self.app.debug:
message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message))
def on_url(self, url):
if not self.url:
self.url = url
else:
self.url += url
def on_header(self, name, value):
self._header_fragment += name
if value is not None:
if (
self._header_fragment == b"Content-Length"
and int(value) > self.request_max_size
):
self.write_error(PayloadTooLarge("Payload Too Large"))
try:
value = value.decode()
except UnicodeDecodeError:
value = value.decode("latin_1")
self.headers.append(
(self._header_fragment.decode().casefold(), value)
)
self._header_fragment = b""
def on_headers_complete(self):
self.request = self.request_class(
url_bytes=self.url,
headers=Header(self.headers),
version=self.parser.get_http_version(),
method=self.parser.get_method().decode(),
transport=self.transport,
app=self.app,
)
self.request.conn_info = self.conn_info
# Remove any existing KeepAlive handler here,
# It will be recreated if required on the new request.
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
self._keep_alive_timeout_handler = None
if self.request.headers.get(EXPECT_HEADER):
self.expect_handler()
if self.is_request_stream:
self._is_stream_handler = self.app.router.is_stream_handler(
self.request
)
if self._is_stream_handler:
self.request.stream = StreamBuffer(
self.request_buffer_queue_size
)
self.execute_request_handler()
def expect_handler(self):
"""
Handler for Expect Header.
"""
expect = self.request.headers.get(EXPECT_HEADER)
if self.request.version == "1.1":
if expect.lower() == "100-continue":
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else:
self.write_error(
HeaderExpectationFailed(f"Unknown Expect: {expect}")
)
def on_body(self, body):
if self.is_request_stream and self._is_stream_handler:
# body chunks can be put into asyncio.Queue out of order if
# multiple tasks put concurrently and the queue is full in python
# 3.7. so we should not create more than one task putting into the
# queue simultaneously.
self._body_chunks.append(body)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
else:
self.request.body_push(body)
async def body_append(self, body):
if (
self.request is None
or self._request_stream_task is None
or self._request_stream_task.cancelled()
):
return
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
async def stream_append(self):
while self._body_chunks:
body = self._body_chunks.popleft()
if self.request:
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
def on_message_complete(self):
# Entire request (headers and whole body) is received.
# We can cancel and remove the request timeout handler now.
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
self._request_timeout_handler = None
if self.is_request_stream and self._is_stream_handler:
self._body_chunks.append(None)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
return
self.request.body_finish()
self.execute_request_handler()
def execute_request_handler(self):
"""
Invoke the request handler defined by the
:func:`sanic.app.Sanic.handle_request` method
:return: None
"""
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = time()
self._request_handler_task = self.loop.create_task(
self.request_handler(
self.request, self.write_response, self.stream_response
)
)
# -------------------------------------------- #
# Responding
# -------------------------------------------- #
def log_response(self, response):
"""
Helper method provided to enable the logging of responses in case if
the :attr:`HttpProtocol.access_log` is enabled.
:param response: Response generated for the current request
:type response: :class:`sanic.response.HTTPResponse` or
:class:`sanic.response.StreamingHTTPResponse`
:return: None
"""
if self.access_log:
extra = {"status": getattr(response, "status", 0)}
if isinstance(response, HTTPResponse):
extra["byte"] = len(response.body)
else:
extra["byte"] = -1
extra["host"] = "UNKNOWN"
if self.request is not None:
if self.request.ip:
extra["host"] = f"{self.request.ip}:{self.request.port}"
extra["request"] = f"{self.request.method} {self.request.url}"
else:
extra["request"] = "nil"
access_logger.info("", extra=extra)
def write_response(self, response):
"""
Writes response content synchronously to the transport.
"""
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try:
keep_alive = self.keep_alive
self.transport.write(
response.output(
self.request.version, keep_alive, self.keep_alive_timeout
)
)
self.log_response(response)
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
except Exception as e:
self.bail_out(f"Writing response failed, connection closed {e!r}")
self._setup_connection()
await self._http.http1()
except CancelledError:
pass
except Exception:
logger.exception("protocol.connection_task uncaught")
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
if self.app.debug and self._http:
ip = self.transport.get_extra_info("peername")
logger.error(
"Connection lost before response written"
f" @ {ip} {self._http.request}"
)
self._last_response_time = time()
self.cleanup()
self._http = None
self._task = None
try:
self.close()
except BaseException:
logger.exception("Closing failed")
async def drain(self):
await self._not_paused.wait()
async def receive_more(self):
"""Wait until more data is received into self._buffer."""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
async def push_data(self, data):
def check_timeouts(self):
"""Runs itself periodically to enforce any expired timeouts."""
try:
if not self._task:
return
duration = current_time() - self._time
stage = self._http.stage
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
logger.debug("KeepAlive Timeout. Closing connection.")
elif stage is Stage.REQUEST and duration > self.request_timeout:
self._http.exception = RequestTimeout("Request Timeout")
elif (
stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
and duration > self.response_timeout
):
self._http.exception = ServiceUnavailable("Response Timeout")
else:
interval = (
min(
self.keep_alive_timeout,
self.request_timeout,
self.response_timeout,
)
/ 2
)
self.loop.call_later(max(0.1, interval), self.check_timeouts)
return
self._task.cancel()
except Exception:
logger.exception("protocol.check_timeouts")
async def send(self, data):
"""Writes data with backpressure control."""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
self.transport.write(data)
async def stream_response(self, response):
"""
Streams a response to the client asynchronously. Attaches
the transport to the response so the response consumer can
write to the response as needed.
"""
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try:
keep_alive = self.keep_alive
response.protocol = self
await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout
)
self.log_response(response)
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
except Exception as e:
self.bail_out(f"Writing response failed, connection closed {e!r}")
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time()
self.cleanup()
def write_error(self, exception):
# An error _is_ a response.
# Don't throw a response timeout, when a response _is_ given.
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
response = None
try:
response = self.error_handler.response(self.request, exception)
version = self.request.version if self.request else "1.1"
self.transport.write(response.output(version))
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before error written @ %s",
self.request.ip if self.request else "Unknown",
)
except Exception as e:
self.bail_out(
f"Writing error failed, connection closed {e!r}",
from_error=True,
)
finally:
if self.parser and (
self.keep_alive or getattr(response, "status", 0) == 408
):
self.log_response(response)
try:
self.transport.close()
except AttributeError:
logger.debug("Connection lost before server could close it.")
def bail_out(self, message, from_error=False):
"""
In case if the transport pipes are closed and the sanic app encounters
an error while writing data to the transport pipe, we log the error
with proper details.
:param message: Error message to display
:param from_error: If the bail out was invoked while handling an
exception scenario.
:type message: str
:type from_error: bool
:return: None
"""
if from_error or self.transport is None or self.transport.is_closing():
logger.error(
"Transport closed @ %s and exception "
"experienced during error handling",
(
self.transport.get_extra_info("peername")
if self.transport is not None
else "N/A"
),
)
logger.debug("Exception:", exc_info=True)
else:
self.write_error(ServerError(message))
logger.error(message)
def cleanup(self):
"""This is called when KeepAlive feature is used,
it resets the connection in order for it to be able
to handle receiving another request on the same connection."""
self.parser = None
self.request = None
self.url = None
self.headers = None
self._request_handler_task = None
self._request_stream_task = None
self._total_request_size = 0
self._is_stream_handler = False
self._time = current_time()
def close_if_idle(self):
"""Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open
"""
if not self.parser and self.transport is not None:
self.transport.close()
if self._http is None or self._http.stage is Stage.IDLE:
self.close()
return True
return False
@@ -679,10 +244,57 @@ class HttpProtocol(asyncio.Protocol):
"""
Force close the connection.
"""
if self.transport is not None:
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self._task = self.loop.create_task(self.connection_task())
self.recv_buffer = bytearray()
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except Exception:
logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
if len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE:
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
except Exception:
logger.exception("protocol.data_received")
def trigger_events(events, loop):
"""Trigger event callbacks (functions or async)

View File

@@ -47,11 +47,21 @@ class SanicTestClient:
async with self.get_new_session() as session:
try:
if method == "request":
args = [url] + list(args)
url = kwargs.pop("http_method", "GET").upper()
response = await getattr(session, method.lower())(
url, *args, **kwargs
)
except NameError:
raise Exception(response.status_code)
except httpx.HTTPError as e:
if hasattr(e, "response"):
response = e.response
else:
logger.error(
f"{method.upper()} {url} received no response!",
exc_info=True,
)
return None
response.body = await response.aread()
response.status = response.status_code
@@ -85,7 +95,6 @@ class SanicTestClient:
):
results = [None, None]
exceptions = []
if gather_request:
def _collect_request(request):
@@ -161,6 +170,9 @@ class SanicTestClient:
except BaseException: # noqa
raise ValueError(f"Request object expected, got ({results})")
def request(self, *args, **kwargs):
return self._sanic_endpoint_test("request", *args, **kwargs)
def get(self, *args, **kwargs):
return self._sanic_endpoint_test("get", *args, **kwargs)