Streaming Server (#1876)
* Streaming request by async for.
* Make all requests streaming and preload body for non-streaming handlers.
* Cleanup of code and avoid mixing streaming responses.
* Async http protocol loop.
* Change of test: don't require early bad request error but only after CRLF-CRLF.
* Add back streaming requests.
* Rewritten request body parser.
* Misc. cleanup, down to 4 failing tests.
* All tests OK.
* Entirely remove request body queue.
* Let black f*ckup the layout
* Better testing error messages on protocol errors.
* Remove StreamBuffer tests because the type is about to be removed.
* Remove tests using the deprecated get_headers function that can no longer be supported. Chunked mode is now autodetected, so do not put content-length header if chunked mode is preferred.
* Major refactoring of HTTP protocol handling (new module http.py added), all requests made streaming. A few compatibility issues and a lot of cleanup to be done remain, 16 tests failing.
* Terminate check_timeouts once connection_task finishes.
* Code cleanup, 14 tests failing.
* Much cleanup, 12 failing...
* Even more cleanup and error checking, 8 failing tests.
* Remove keep-alive header from responses. First of all, it should say timeout=<value> which wasn't the case with existing implementation, and secondly none of the other web servers I tried include this header.
* Everything but CustomServer OK.
* Linter
* Disable custom protocol test
* Remove unnecessary variables, optimise performance.
* A test was missing that body_init/body_push/body_finish are never called. Rewritten using receive_body and case switching to make it fail if bypassed.
* Minor fixes.
* Remove unused code.
* Py 3.8 check for deprecated loop argument.
* Fix a middleware cancellation handling test with py38.
* Linter 'n fixes
* Typing
* Stricter handling of request header size
* More specific error messages on Payload Too Large.
* Init http.response = None
* Messages further tuned.
* Always try to consume request body, plus minor cleanup.
* Add a missing check in case of close_if_idle on a dead connection.
* Avoid error messages on PayloadTooLarge.
* Add test for new API.
* json takes str, not bytes
* Default to no maximum request size for streaming handlers.
* Fix chunked mode crash.
* Header values should be strictly ASCII but both UTF-8 and Latin-1 exist. Use UTF-8B to
cope with all.
* Refactoring and cleanup.
* Unify response header processing of ASGI and asyncio modes.
* Avoid special handling of StreamingHTTPResponse.
* 35 % speedup in HTTP/1.1 response formatting (not so much overall effect).
* Duplicate set-cookie headers were being produced.
* Cleanup processed_headers some more.
* Linting
* Import ordering
* Response middleware ran by async request.respond().
* Need to check if transport is closing to avoid getting stuck in sending loops after peer has disconnected.
* Middleware and error handling refactoring.
* Linter
* Fix tracking of HTTP stage when writing to transport fails.
* Add clarifying comment
* Add a check for request body functions and a test for NotImplementedError.
* Linter and typing
* These must be tuples + hack mypy warnings away.
* New streaming test and minor fixes.
* Constant receive buffer size.
* 256 KiB send and receive buffers.
* Revert "256 KiB send and receive buffers."
This reverts commit abc1e3edb2
.
* app.handle_exception already sends the response.
* Improved handling of errors during request.
* An odd hack to avoid an httpx limitation that causes test failures.
* Limit request header size to 8 KiB at most.
* Remove unnecessary use of format string.
* Cleanup tests
* Remove artifact
* Fix type checking
* Mark test for skipping
* Cleanup some edge cases
* Add ignore_body flag to safe methods
* Add unit tests for timeout logic
* Add unit tests for timeout logic
* Fix Mock usage in timeout test
* Change logging test to only logger in handler
* Windows py3.8 logging issue with current testing client
* Add test_header_size_exceeded
* Resolve merge conflicts
* Add request middleware to hard exception handling
* Add request middleware to hard exception handling
* Request middleware on exception handlers
* Linting
* Cleanup deprecations
Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
574a9c27a6
commit
7028eae083
|
@ -2,37 +2,38 @@ from sanic import Blueprint, Sanic
|
|||
from sanic.response import file, json
|
||||
|
||||
app = Sanic(__name__)
|
||||
blueprint = Blueprint('name', url_prefix='/my_blueprint')
|
||||
blueprint2 = Blueprint('name2', url_prefix='/my_blueprint2')
|
||||
blueprint3 = Blueprint('name3', url_prefix='/my_blueprint3')
|
||||
blueprint = Blueprint("name", url_prefix="/my_blueprint")
|
||||
blueprint2 = Blueprint("name2", url_prefix="/my_blueprint2")
|
||||
blueprint3 = Blueprint("name3", url_prefix="/my_blueprint3")
|
||||
|
||||
|
||||
@blueprint.route('/foo')
|
||||
@blueprint.route("/foo")
|
||||
async def foo(request):
|
||||
return json({'msg': 'hi from blueprint'})
|
||||
return json({"msg": "hi from blueprint"})
|
||||
|
||||
|
||||
@blueprint2.route('/foo')
|
||||
@blueprint2.route("/foo")
|
||||
async def foo2(request):
|
||||
return json({'msg': 'hi from blueprint2'})
|
||||
return json({"msg": "hi from blueprint2"})
|
||||
|
||||
|
||||
@blueprint3.route('/foo')
|
||||
@blueprint3.route("/foo")
|
||||
async def index(request):
|
||||
return await file('websocket.html')
|
||||
return await file("websocket.html")
|
||||
|
||||
|
||||
@app.websocket('/feed')
|
||||
@app.websocket("/feed")
|
||||
async def foo3(request, ws):
|
||||
while True:
|
||||
data = 'hello!'
|
||||
print('Sending: ' + data)
|
||||
data = "hello!"
|
||||
print("Sending: " + data)
|
||||
await ws.send(data)
|
||||
data = await ws.recv()
|
||||
print('Received: ' + data)
|
||||
print("Received: " + data)
|
||||
|
||||
|
||||
app.blueprint(blueprint)
|
||||
app.blueprint(blueprint2)
|
||||
app.blueprint(blueprint3)
|
||||
|
||||
app.run(host="0.0.0.0", port=8000, debug=True)
|
||||
app.run(host="0.0.0.0", port=9999, debug=True)
|
||||
|
|
|
@ -48,14 +48,14 @@ async def handler_file_stream(request):
|
|||
)
|
||||
|
||||
|
||||
@app.route("/stream", stream=True)
|
||||
@app.post("/stream", stream=True)
|
||||
async def handler_stream(request):
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
body = body.decode("utf-8").replace("1", "A")
|
||||
# await response.write(body)
|
||||
await response.write(body)
|
||||
return response.stream(body)
|
||||
|
||||
|
||||
|
|
248
sanic/app.py
248
sanic/app.py
|
@ -4,24 +4,27 @@ import os
|
|||
import re
|
||||
|
||||
from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
|
||||
from asyncio.futures import Future
|
||||
from collections import defaultdict, deque
|
||||
from functools import partial
|
||||
from inspect import isawaitable, signature
|
||||
from socket import socket
|
||||
from ssl import Purpose, SSLContext, create_default_context
|
||||
from traceback import format_exc
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
|
||||
from urllib.parse import urlencode, urlunparse
|
||||
|
||||
from sanic import reloader_helpers
|
||||
from sanic.asgi import ASGIApp
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.config import BASE_LOGO, Config
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.exceptions import SanicException, ServerError, URLBuildError
|
||||
from sanic.handlers import ErrorHandler
|
||||
from sanic.handlers import ErrorHandler, ListenerType, MiddlewareType
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.request import Request
|
||||
from sanic.response import BaseHTTPResponse, HTTPResponse
|
||||
from sanic.router import Router
|
||||
from sanic.server import (
|
||||
AsyncioServer,
|
||||
|
@ -42,16 +45,16 @@ class Sanic:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
name=None,
|
||||
router=None,
|
||||
error_handler=None,
|
||||
load_env=True,
|
||||
request_class=None,
|
||||
strict_slashes=False,
|
||||
log_config=None,
|
||||
configure_logging=True,
|
||||
register=None,
|
||||
):
|
||||
name: str = None,
|
||||
router: Router = None,
|
||||
error_handler: ErrorHandler = None,
|
||||
load_env: bool = True,
|
||||
request_class: Request = None,
|
||||
strict_slashes: bool = False,
|
||||
log_config: Optional[Dict[str, Any]] = None,
|
||||
configure_logging: bool = True,
|
||||
register: Optional[bool] = None,
|
||||
) -> None:
|
||||
|
||||
# Get name from previous stack frame
|
||||
if name is None:
|
||||
|
@ -59,7 +62,6 @@ class Sanic:
|
|||
"Sanic instance cannot be unnamed. "
|
||||
"Please use Sanic(name='your_application_name') instead.",
|
||||
)
|
||||
|
||||
# logging
|
||||
if configure_logging:
|
||||
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
|
||||
|
@ -70,22 +72,21 @@ class Sanic:
|
|||
self.request_class = request_class
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
self.config = Config(load_env=load_env)
|
||||
self.request_middleware = deque()
|
||||
self.response_middleware = deque()
|
||||
self.blueprints = {}
|
||||
self._blueprint_order = []
|
||||
self.request_middleware: Iterable[MiddlewareType] = deque()
|
||||
self.response_middleware: Iterable[MiddlewareType] = deque()
|
||||
self.blueprints: Dict[str, Blueprint] = {}
|
||||
self._blueprint_order: List[Blueprint] = []
|
||||
self.configure_logging = configure_logging
|
||||
self.debug = None
|
||||
self.sock = None
|
||||
self.strict_slashes = strict_slashes
|
||||
self.listeners = defaultdict(list)
|
||||
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
|
||||
self.is_stopping = False
|
||||
self.is_running = False
|
||||
self.is_request_stream = False
|
||||
self.websocket_enabled = False
|
||||
self.websocket_tasks = set()
|
||||
self.named_request_middleware = {}
|
||||
self.named_response_middleware = {}
|
||||
self.websocket_tasks: Set[Future] = set()
|
||||
self.named_request_middleware: Dict[str, MiddlewareType] = {}
|
||||
self.named_response_middleware: Dict[str, MiddlewareType] = {}
|
||||
# Register alternative method names
|
||||
self.go_fast = self.run
|
||||
|
||||
|
@ -162,6 +163,7 @@ class Sanic:
|
|||
stream=False,
|
||||
version=None,
|
||||
name=None,
|
||||
ignore_body=False,
|
||||
):
|
||||
"""Decorate a function to be registered as a route
|
||||
|
||||
|
@ -180,9 +182,6 @@ class Sanic:
|
|||
if not uri.startswith("/"):
|
||||
uri = "/" + uri
|
||||
|
||||
if stream:
|
||||
self.is_request_stream = True
|
||||
|
||||
if strict_slashes is None:
|
||||
strict_slashes = self.strict_slashes
|
||||
|
||||
|
@ -215,6 +214,7 @@ class Sanic:
|
|||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
)
|
||||
)
|
||||
return routes, handler
|
||||
|
@ -223,7 +223,13 @@ class Sanic:
|
|||
|
||||
# Shorthand method decorators
|
||||
def get(
|
||||
self, uri, host=None, strict_slashes=None, version=None, name=None
|
||||
self,
|
||||
uri,
|
||||
host=None,
|
||||
strict_slashes=None,
|
||||
version=None,
|
||||
name=None,
|
||||
ignore_body=True,
|
||||
):
|
||||
"""
|
||||
Add an API URL under the **GET** *HTTP* method
|
||||
|
@ -243,6 +249,7 @@ class Sanic:
|
|||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
)
|
||||
|
||||
def post(
|
||||
|
@ -306,7 +313,13 @@ class Sanic:
|
|||
)
|
||||
|
||||
def head(
|
||||
self, uri, host=None, strict_slashes=None, version=None, name=None
|
||||
self,
|
||||
uri,
|
||||
host=None,
|
||||
strict_slashes=None,
|
||||
version=None,
|
||||
name=None,
|
||||
ignore_body=True,
|
||||
):
|
||||
return self.route(
|
||||
uri,
|
||||
|
@ -315,10 +328,17 @@ class Sanic:
|
|||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
)
|
||||
|
||||
def options(
|
||||
self, uri, host=None, strict_slashes=None, version=None, name=None
|
||||
self,
|
||||
uri,
|
||||
host=None,
|
||||
strict_slashes=None,
|
||||
version=None,
|
||||
name=None,
|
||||
ignore_body=True,
|
||||
):
|
||||
"""
|
||||
Add an API URL under the **OPTIONS** *HTTP* method
|
||||
|
@ -338,6 +358,7 @@ class Sanic:
|
|||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
)
|
||||
|
||||
def patch(
|
||||
|
@ -371,7 +392,13 @@ class Sanic:
|
|||
)
|
||||
|
||||
def delete(
|
||||
self, uri, host=None, strict_slashes=None, version=None, name=None
|
||||
self,
|
||||
uri,
|
||||
host=None,
|
||||
strict_slashes=None,
|
||||
version=None,
|
||||
name=None,
|
||||
ignore_body=True,
|
||||
):
|
||||
"""
|
||||
Add an API URL under the **DELETE** *HTTP* method
|
||||
|
@ -391,6 +418,7 @@ class Sanic:
|
|||
strict_slashes=strict_slashes,
|
||||
version=version,
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
)
|
||||
|
||||
def add_route(
|
||||
|
@ -497,6 +525,7 @@ class Sanic:
|
|||
websocket_handler.__name__ = (
|
||||
"websocket_handler_" + handler.__name__
|
||||
)
|
||||
websocket_handler.is_websocket = True
|
||||
routes.extend(
|
||||
self.router.add(
|
||||
uri=uri,
|
||||
|
@ -861,7 +890,52 @@ class Sanic:
|
|||
"""
|
||||
pass
|
||||
|
||||
async def handle_request(self, request, write_callback, stream_callback):
|
||||
async def handle_exception(self, request, exception):
|
||||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
response = await self._run_request_middleware(
|
||||
request, request_name=None
|
||||
)
|
||||
# No middleware results
|
||||
if not response:
|
||||
try:
|
||||
response = self.error_handler.response(request, exception)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
except Exception as e:
|
||||
if isinstance(e, SanicException):
|
||||
response = self.error_handler.default(request, e)
|
||||
elif self.debug:
|
||||
response = HTTPResponse(
|
||||
(
|
||||
f"Error while handling error: {e}\n"
|
||||
f"Stack: {format_exc()}"
|
||||
),
|
||||
status=500,
|
||||
)
|
||||
else:
|
||||
response = HTTPResponse(
|
||||
"An error occurred while handling an error", status=500
|
||||
)
|
||||
if response is not None:
|
||||
try:
|
||||
response = await request.respond(response)
|
||||
except BaseException:
|
||||
# Skip response middleware
|
||||
request.stream.respond(response)
|
||||
await response.send(end_stream=True)
|
||||
raise
|
||||
else:
|
||||
response = request.stream.response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
await response.send(end_stream=True)
|
||||
else:
|
||||
raise ServerError(
|
||||
f"Invalid response type {response!r} (need HTTPResponse)"
|
||||
)
|
||||
|
||||
async def handle_request(self, request):
|
||||
"""Take a request from the HTTP Server and return a response object
|
||||
to be sent back The HTTP Server only expects a response object, so
|
||||
exception handling must be done here
|
||||
|
@ -877,13 +951,27 @@ class Sanic:
|
|||
# Define `response` var here to remove warnings about
|
||||
# allocation before assignment below.
|
||||
response = None
|
||||
cancelled = False
|
||||
name = None
|
||||
try:
|
||||
# Fetch handler from router
|
||||
handler, args, kwargs, uri, name, endpoint = self.router.get(
|
||||
request
|
||||
)
|
||||
(
|
||||
handler,
|
||||
args,
|
||||
kwargs,
|
||||
uri,
|
||||
name,
|
||||
endpoint,
|
||||
ignore_body,
|
||||
) = self.router.get(request)
|
||||
request.name = name
|
||||
|
||||
if request.stream.request_body and not ignore_body:
|
||||
if self.router.is_stream_handler(request):
|
||||
# Streaming handler: lift the size limit
|
||||
request.stream.request_max_size = float("inf")
|
||||
else:
|
||||
# Non-streaming handler: preload body
|
||||
await request.receive_body()
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
|
@ -912,72 +1000,31 @@ class Sanic:
|
|||
response = handler(request, *args, **kwargs)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
if response:
|
||||
response = await request.respond(response)
|
||||
else:
|
||||
response = request.stream.response
|
||||
# Make sure that response is finished / run StreamingHTTP callback
|
||||
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
await response.send(end_stream=True)
|
||||
else:
|
||||
try:
|
||||
# Fastest method for checking if the property exists
|
||||
handler.is_websocket
|
||||
except AttributeError:
|
||||
raise ServerError(
|
||||
f"Invalid response type {response!r} "
|
||||
"(need HTTPResponse)"
|
||||
)
|
||||
|
||||
except CancelledError:
|
||||
# If response handler times out, the server handles the error
|
||||
# and cancels the handle_request job.
|
||||
# In this case, the transport is already closed and we cannot
|
||||
# issue a response.
|
||||
response = None
|
||||
cancelled = True
|
||||
raise
|
||||
except Exception as e:
|
||||
# -------------------------------------------- #
|
||||
# Response Generation Failed
|
||||
# -------------------------------------------- #
|
||||
|
||||
try:
|
||||
response = self.error_handler.response(request, e)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
except Exception as e:
|
||||
if isinstance(e, SanicException):
|
||||
response = self.error_handler.default(
|
||||
request=request, exception=e
|
||||
)
|
||||
elif self.debug:
|
||||
response = HTTPResponse(
|
||||
f"Error while "
|
||||
f"handling error: {e}\nStack: {format_exc()}",
|
||||
status=500,
|
||||
)
|
||||
else:
|
||||
response = HTTPResponse(
|
||||
"An error occurred while handling an error", status=500
|
||||
)
|
||||
finally:
|
||||
# -------------------------------------------- #
|
||||
# Response Middleware
|
||||
# -------------------------------------------- #
|
||||
# Don't run response middleware if response is None
|
||||
if response is not None:
|
||||
try:
|
||||
response = await self._run_response_middleware(
|
||||
request, response, request_name=name
|
||||
)
|
||||
except CancelledError:
|
||||
# Response middleware can timeout too, as above.
|
||||
response = None
|
||||
cancelled = True
|
||||
except BaseException:
|
||||
error_logger.exception(
|
||||
"Exception occurred in one of response "
|
||||
"middleware handlers"
|
||||
)
|
||||
if cancelled:
|
||||
raise CancelledError()
|
||||
|
||||
# pass the response to the correct callback
|
||||
if write_callback is None or isinstance(
|
||||
response, StreamingHTTPResponse
|
||||
):
|
||||
if stream_callback:
|
||||
await stream_callback(response)
|
||||
else:
|
||||
# Should only end here IF it is an ASGI websocket.
|
||||
# TODO:
|
||||
# - Add exception handling
|
||||
pass
|
||||
else:
|
||||
write_callback(response)
|
||||
await self.handle_exception(request, e)
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Testing
|
||||
|
@ -1213,7 +1260,12 @@ class Sanic:
|
|||
request_name, deque()
|
||||
)
|
||||
applicable_middleware = self.request_middleware + named_middleware
|
||||
if applicable_middleware:
|
||||
|
||||
# request.request_middleware_started is meant as a stop-gap solution
|
||||
# until RFC 1630 is adopted
|
||||
if applicable_middleware and not request.request_middleware_started:
|
||||
request.request_middleware_started = True
|
||||
|
||||
for middleware in applicable_middleware:
|
||||
response = middleware(request)
|
||||
if isawaitable(response):
|
||||
|
@ -1236,6 +1288,8 @@ class Sanic:
|
|||
_response = await _response
|
||||
if _response:
|
||||
response = _response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
response = request.stream.respond(response)
|
||||
break
|
||||
return response
|
||||
|
||||
|
|
212
sanic/asgi.py
212
sanic/asgi.py
|
@ -2,27 +2,15 @@ import asyncio
|
|||
import warnings
|
||||
|
||||
from inspect import isawaitable
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import sanic.app # noqa
|
||||
|
||||
from sanic.compat import Header
|
||||
from sanic.exceptions import InvalidUsage, ServerError
|
||||
from sanic.log import logger
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.server import ConnInfo, StreamBuffer
|
||||
from sanic.server import ConnInfo
|
||||
from sanic.websocket import WebSocketConnection
|
||||
|
||||
|
||||
|
@ -84,7 +72,7 @@ class MockTransport:
|
|||
|
||||
def get_extra_info(self, info: str) -> Union[str, bool, None]:
|
||||
if info == "peername":
|
||||
return self.scope.get("server")
|
||||
return self.scope.get("client")
|
||||
elif info == "sslcontext":
|
||||
return self.scope.get("scheme") in ["https", "wss"]
|
||||
return None
|
||||
|
@ -151,7 +139,7 @@ class Lifespan:
|
|||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
if response and isawaitable(response):
|
||||
await response
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
@ -171,7 +159,7 @@ class Lifespan:
|
|||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
if response and isawaitable(response):
|
||||
await response
|
||||
|
||||
async def __call__(
|
||||
|
@ -192,7 +180,6 @@ class ASGIApp:
|
|||
sanic_app: "sanic.app.Sanic"
|
||||
request: Request
|
||||
transport: MockTransport
|
||||
do_stream: bool
|
||||
lifespan: Lifespan
|
||||
ws: Optional[WebSocketConnection]
|
||||
|
||||
|
@ -215,9 +202,6 @@ class ASGIApp:
|
|||
for key, value in scope.get("headers", [])
|
||||
]
|
||||
)
|
||||
instance.do_stream = (
|
||||
True if headers.get("expect") == "100-continue" else False
|
||||
)
|
||||
instance.lifespan = Lifespan(instance)
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
|
@ -244,9 +228,7 @@ class ASGIApp:
|
|||
)
|
||||
await instance.ws.accept()
|
||||
else:
|
||||
pass
|
||||
# TODO:
|
||||
# - close connection
|
||||
raise ServerError("Received unknown ASGI scope")
|
||||
|
||||
request_class = sanic_app.request_class or Request
|
||||
instance.request = request_class(
|
||||
|
@ -257,161 +239,57 @@ class ASGIApp:
|
|||
instance.transport,
|
||||
sanic_app,
|
||||
)
|
||||
instance.request.stream = instance
|
||||
instance.request_body = True
|
||||
instance.request.conn_info = ConnInfo(instance.transport)
|
||||
|
||||
if sanic_app.is_request_stream:
|
||||
is_stream_handler = sanic_app.router.is_stream_handler(
|
||||
instance.request
|
||||
)
|
||||
if is_stream_handler:
|
||||
instance.request.stream = StreamBuffer(
|
||||
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
|
||||
)
|
||||
instance.do_stream = True
|
||||
|
||||
return instance
|
||||
|
||||
async def read_body(self) -> bytes:
|
||||
"""
|
||||
Read and return the entire body from an incoming ASGI message.
|
||||
"""
|
||||
body = b""
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await self.transport.receive()
|
||||
body += message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
return body
|
||||
|
||||
async def stream_body(self) -> None:
|
||||
async def read(self) -> Optional[bytes]:
|
||||
"""
|
||||
Read and stream the body in chunks from an incoming ASGI message.
|
||||
"""
|
||||
more_body = True
|
||||
message = await self.transport.receive()
|
||||
if not message.get("more_body", False):
|
||||
self.request_body = False
|
||||
return None
|
||||
return message.get("body", b"")
|
||||
|
||||
while more_body:
|
||||
message = await self.transport.receive()
|
||||
chunk = message.get("body", b"")
|
||||
await self.request.stream.put(chunk)
|
||||
async def __aiter__(self):
|
||||
while self.request_body:
|
||||
data = await self.read()
|
||||
if data:
|
||||
yield data
|
||||
|
||||
more_body = message.get("more_body", False)
|
||||
def respond(self, response):
|
||||
response.stream, self.response = self, response
|
||||
return response
|
||||
|
||||
await self.request.stream.put(None)
|
||||
async def send(self, data, end_stream):
|
||||
if self.response:
|
||||
response, self.response = self.response, None
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": response.status,
|
||||
"headers": response.processed_headers,
|
||||
}
|
||||
)
|
||||
response_body = getattr(response, "body", None)
|
||||
if response_body:
|
||||
data = response_body + data if data else response_body
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": data.encode() if hasattr(data, "encode") else data,
|
||||
"more_body": not end_stream,
|
||||
}
|
||||
)
|
||||
|
||||
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable
|
||||
|
||||
async def __call__(self) -> None:
|
||||
"""
|
||||
Handle the incoming request.
|
||||
"""
|
||||
if not self.do_stream:
|
||||
self.request.body = await self.read_body()
|
||||
else:
|
||||
self.sanic_app.loop.create_task(self.stream_body())
|
||||
|
||||
handler = self.sanic_app.handle_request
|
||||
callback = None if self.ws else self.stream_callback
|
||||
await handler(self.request, None, callback)
|
||||
|
||||
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable
|
||||
|
||||
async def stream_callback(
|
||||
self, response: Union[HTTPResponse, StreamingHTTPResponse]
|
||||
) -> None:
|
||||
"""
|
||||
Write the response.
|
||||
"""
|
||||
headers: List[Tuple[bytes, bytes]] = []
|
||||
cookies: Dict[str, str] = {}
|
||||
content_length: List[str] = []
|
||||
try:
|
||||
content_length = response.headers.popall("content-length", [])
|
||||
cookies = {
|
||||
v.key: v
|
||||
for _, v in list(
|
||||
filter(
|
||||
lambda item: item[0].lower() == "set-cookie",
|
||||
response.headers.items(),
|
||||
)
|
||||
)
|
||||
}
|
||||
headers += [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
if name.lower() not in ["set-cookie"]
|
||||
]
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
"Invalid response object for url %s, "
|
||||
"Expected Type: HTTPResponse, Actual Type: %s",
|
||||
self.request.url,
|
||||
type(response),
|
||||
)
|
||||
exception = ServerError("Invalid response type")
|
||||
response = self.sanic_app.error_handler.response(
|
||||
self.request, exception
|
||||
)
|
||||
headers = [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
if name not in (b"Set-Cookie",)
|
||||
]
|
||||
|
||||
response.asgi = True
|
||||
is_streaming = isinstance(response, StreamingHTTPResponse)
|
||||
if is_streaming and getattr(response, "chunked", False):
|
||||
# disable sanic chunking, this is done at the ASGI-server level
|
||||
setattr(response, "chunked", False)
|
||||
# content-length header is removed to signal to the ASGI-server
|
||||
# to use automatic-chunking if it supports it
|
||||
elif len(content_length) > 0:
|
||||
headers += [
|
||||
(b"content-length", str(content_length[0]).encode("latin-1"))
|
||||
]
|
||||
elif not is_streaming:
|
||||
headers += [
|
||||
(
|
||||
b"content-length",
|
||||
str(len(getattr(response, "body", b""))).encode("latin-1"),
|
||||
)
|
||||
]
|
||||
|
||||
if "content-type" not in response.headers:
|
||||
headers += [
|
||||
(b"content-type", str(response.content_type).encode("latin-1"))
|
||||
]
|
||||
|
||||
if response.cookies:
|
||||
cookies.update(
|
||||
{
|
||||
v.key: v
|
||||
for _, v in response.cookies.items()
|
||||
if v.key not in cookies.keys()
|
||||
}
|
||||
)
|
||||
|
||||
headers += [
|
||||
(b"set-cookie", cookie.encode("utf-8"))
|
||||
for k, cookie in cookies.items()
|
||||
]
|
||||
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": response.status,
|
||||
"headers": headers,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(response, StreamingHTTPResponse):
|
||||
response.protocol = self.transport.get_protocol()
|
||||
await response.stream()
|
||||
await response.protocol.complete()
|
||||
|
||||
else:
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": response.body,
|
||||
"more_body": False,
|
||||
}
|
||||
)
|
||||
await self.sanic_app.handle_request(self.request)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
|
||||
from sys import argv
|
||||
|
@ -6,6 +7,9 @@ from sys import argv
|
|||
from multidict import CIMultiDict # type: ignore
|
||||
|
||||
|
||||
OS_IS_WINDOWS = os.name == "nt"
|
||||
|
||||
|
||||
class Header(CIMultiDict):
|
||||
def get_all(self, key):
|
||||
return self.getall(key, default=[])
|
||||
|
@ -14,13 +18,13 @@ class Header(CIMultiDict):
|
|||
use_trio = argv[0].endswith("hypercorn") and "trio" in argv
|
||||
|
||||
if use_trio:
|
||||
from trio import Path # type: ignore
|
||||
from trio import open_file as open_async # type: ignore
|
||||
import trio # type: ignore
|
||||
|
||||
def stat_async(path):
|
||||
return Path(path).stat()
|
||||
|
||||
return trio.Path(path).stat()
|
||||
|
||||
open_async = trio.open_file
|
||||
CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled])
|
||||
else:
|
||||
from aiofiles import open as aio_open # type: ignore
|
||||
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
|
||||
|
@ -28,6 +32,8 @@ else:
|
|||
async def open_async(file, mode="r", **kwargs):
|
||||
return aio_open(file, mode, **kwargs)
|
||||
|
||||
CancelledErrors = tuple([asyncio.CancelledError])
|
||||
|
||||
|
||||
def ctrlc_workaround_for_windows(app):
|
||||
async def stay_active(app):
|
||||
|
|
|
@ -23,6 +23,7 @@ BASE_LOGO = """
|
|||
DEFAULT_CONFIG = {
|
||||
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes
|
||||
"REQUEST_BUFFER_QUEUE_SIZE": 100,
|
||||
"REQUEST_BUFFER_SIZE": 65536, # 64 KiB
|
||||
"REQUEST_TIMEOUT": 60, # 60 seconds
|
||||
"RESPONSE_TIMEOUT": 60, # 60 seconds
|
||||
"KEEP_ALIVE": True,
|
||||
|
|
|
@ -206,7 +206,6 @@ class TextRenderer(BaseRenderer):
|
|||
_, exc_value, __ = sys.exc_info()
|
||||
exceptions = []
|
||||
|
||||
# traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
|
||||
lines = [
|
||||
f"{self.exception.__class__.__name__}: {self.exception} while "
|
||||
f"handling path {self.request.path}",
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from asyncio.events import AbstractEventLoop
|
||||
from traceback import format_exc
|
||||
from typing import Any, Callable, Coroutine, Optional, TypeVar, Union
|
||||
|
||||
from sanic.errorpages import exception_response
|
||||
from sanic.exceptions import (
|
||||
|
@ -7,7 +9,23 @@ from sanic.exceptions import (
|
|||
InvalidRangeType,
|
||||
)
|
||||
from sanic.log import logger
|
||||
from sanic.response import text
|
||||
from sanic.request import Request
|
||||
from sanic.response import BaseHTTPResponse, HTTPResponse, text
|
||||
|
||||
|
||||
Sanic = TypeVar("Sanic")
|
||||
|
||||
MiddlewareResponse = Union[
|
||||
Optional[HTTPResponse], Coroutine[Any, Any, Optional[HTTPResponse]]
|
||||
]
|
||||
RequestMiddlewareType = Callable[[Request], MiddlewareResponse]
|
||||
ResponseMiddlewareType = Callable[
|
||||
[Request, BaseHTTPResponse], MiddlewareResponse
|
||||
]
|
||||
MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType]
|
||||
ListenerType = Callable[
|
||||
[Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]]
|
||||
]
|
||||
|
||||
|
||||
class ErrorHandler:
|
||||
|
|
|
@ -7,6 +7,7 @@ from sanic.helpers import STATUS_CODES
|
|||
|
||||
|
||||
HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str
|
||||
HeaderBytesIterable = Iterable[Tuple[bytes, bytes]]
|
||||
Options = Dict[str, Union[int, str]] # key=value fields in various headers
|
||||
OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys
|
||||
|
||||
|
@ -175,26 +176,18 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]:
|
|||
return host.lower(), int(port) if port is not None else None
|
||||
|
||||
|
||||
def format_http1(headers: HeaderIterable) -> bytes:
|
||||
"""Convert a headers iterable into HTTP/1 header format.
|
||||
|
||||
- Outputs UTF-8 bytes where each header line ends with \\r\\n.
|
||||
- Values are converted into strings if necessary.
|
||||
"""
|
||||
return "".join(f"{name}: {val}\r\n" for name, val in headers).encode()
|
||||
_HTTP1_STATUSLINES = [
|
||||
b"HTTP/1.1 %d %b\r\n" % (status, STATUS_CODES.get(status, b"UNKNOWN"))
|
||||
for status in range(1000)
|
||||
]
|
||||
|
||||
|
||||
def format_http1_response(
|
||||
status: int, headers: HeaderIterable, body=b""
|
||||
) -> bytes:
|
||||
"""Format a full HTTP/1.1 response.
|
||||
|
||||
- If `body` is included, content-length must be specified in headers.
|
||||
"""
|
||||
headerbytes = format_http1(headers)
|
||||
return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % (
|
||||
status,
|
||||
STATUS_CODES.get(status, b"UNKNOWN"),
|
||||
headerbytes,
|
||||
body,
|
||||
)
|
||||
def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
|
||||
"""Format a HTTP/1.1 response header."""
|
||||
# Note: benchmarks show that here bytes concat is faster than bytearray,
|
||||
# b"".join() or %-formatting. %timeit any changes you make.
|
||||
ret = _HTTP1_STATUSLINES[status]
|
||||
for h in headers:
|
||||
ret += b"%b: %b\r\n" % h
|
||||
ret += b"\r\n"
|
||||
return ret
|
||||
|
|
475
sanic/http.py
Normal file
475
sanic/http.py
Normal file
|
@ -0,0 +1,475 @@
|
|||
from asyncio import CancelledError, sleep
|
||||
from enum import Enum
|
||||
|
||||
from sanic.compat import Header
|
||||
from sanic.exceptions import (
|
||||
HeaderExpectationFailed,
|
||||
InvalidUsage,
|
||||
PayloadTooLarge,
|
||||
ServerError,
|
||||
ServiceUnavailable,
|
||||
)
|
||||
from sanic.headers import format_http1_response
|
||||
from sanic.helpers import has_message_body
|
||||
from sanic.log import access_logger, logger
|
||||
|
||||
|
||||
class Stage(Enum):
|
||||
IDLE = 0 # Waiting for request
|
||||
REQUEST = 1 # Request headers being received
|
||||
HANDLER = 3 # Headers done, handler running
|
||||
RESPONSE = 4 # Response headers sent, body in progress
|
||||
FAILED = 100 # Unrecoverable state (error while sending response)
|
||||
|
||||
|
||||
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
|
||||
|
||||
|
||||
class Http:
|
||||
__slots__ = [
|
||||
"_send",
|
||||
"_receive_more",
|
||||
"recv_buffer",
|
||||
"protocol",
|
||||
"expecting_continue",
|
||||
"stage",
|
||||
"keep_alive",
|
||||
"head_only",
|
||||
"request",
|
||||
"exception",
|
||||
"url",
|
||||
"request_body",
|
||||
"request_bytes",
|
||||
"request_bytes_left",
|
||||
"request_max_size",
|
||||
"response",
|
||||
"response_func",
|
||||
"response_bytes_left",
|
||||
"upgrade_websocket",
|
||||
]
|
||||
|
||||
def __init__(self, protocol):
|
||||
self._send = protocol.send
|
||||
self._receive_more = protocol.receive_more
|
||||
self.recv_buffer = protocol.recv_buffer
|
||||
self.protocol = protocol
|
||||
self.expecting_continue = False
|
||||
self.stage = Stage.IDLE
|
||||
self.request_body = None
|
||||
self.request_bytes = None
|
||||
self.request_bytes_left = None
|
||||
self.request_max_size = protocol.request_max_size
|
||||
self.keep_alive = True
|
||||
self.head_only = None
|
||||
self.request = None
|
||||
self.response = None
|
||||
self.exception = None
|
||||
self.url = None
|
||||
self.upgrade_websocket = False
|
||||
|
||||
def __bool__(self):
|
||||
"""Test if request handling is in progress"""
|
||||
return self.stage in (Stage.HANDLER, Stage.RESPONSE)
|
||||
|
||||
async def http1(self):
|
||||
"""HTTP 1.1 connection handler"""
|
||||
while True: # As long as connection stays keep-alive
|
||||
try:
|
||||
# Receive and handle a request
|
||||
self.stage = Stage.REQUEST
|
||||
self.response_func = self.http1_response_header
|
||||
|
||||
await self.http1_request_header()
|
||||
|
||||
self.request.conn_info = self.protocol.conn_info
|
||||
await self.protocol.request_handler(self.request)
|
||||
|
||||
# Handler finished, response should've been sent
|
||||
if self.stage is Stage.HANDLER and not self.upgrade_websocket:
|
||||
raise ServerError("Handler produced no response")
|
||||
|
||||
if self.stage is Stage.RESPONSE:
|
||||
await self.response.send(end_stream=True)
|
||||
except CancelledError:
|
||||
# Write an appropriate response before exiting
|
||||
e = self.exception or ServiceUnavailable("Cancelled")
|
||||
self.exception = None
|
||||
self.keep_alive = False
|
||||
await self.error_response(e)
|
||||
except Exception as e:
|
||||
# Write an error response
|
||||
await self.error_response(e)
|
||||
|
||||
# Try to consume any remaining request body
|
||||
if self.request_body:
|
||||
if self.response and 200 <= self.response.status < 300:
|
||||
logger.error(f"{self.request} body not consumed.")
|
||||
|
||||
try:
|
||||
async for _ in self:
|
||||
pass
|
||||
except PayloadTooLarge:
|
||||
# We won't read the body and that may cause httpx and
|
||||
# tests to fail. This little delay allows clients to push
|
||||
# a small request into network buffers before we close the
|
||||
# socket, so that they are then able to read the response.
|
||||
await sleep(0.001)
|
||||
self.keep_alive = False
|
||||
|
||||
# Exit and disconnect if no more requests can be taken
|
||||
if self.stage is not Stage.IDLE or not self.keep_alive:
|
||||
break
|
||||
|
||||
# Wait for next request
|
||||
if not self.recv_buffer:
|
||||
await self._receive_more()
|
||||
|
||||
async def http1_request_header(self):
|
||||
"""Receive and parse request header into self.request."""
|
||||
HEADER_MAX_SIZE = min(8192, self.request_max_size)
|
||||
# Receive until full header is in buffer
|
||||
buf = self.recv_buffer
|
||||
pos = 0
|
||||
|
||||
while True:
|
||||
pos = buf.find(b"\r\n\r\n", pos)
|
||||
if pos != -1:
|
||||
break
|
||||
|
||||
pos = max(0, len(buf) - 3)
|
||||
if pos >= HEADER_MAX_SIZE:
|
||||
break
|
||||
|
||||
await self._receive_more()
|
||||
|
||||
if pos >= HEADER_MAX_SIZE:
|
||||
raise PayloadTooLarge("Request header exceeds the size limit")
|
||||
|
||||
# Parse header content
|
||||
try:
|
||||
raw_headers = buf[:pos].decode(errors="surrogateescape")
|
||||
reqline, *raw_headers = raw_headers.split("\r\n")
|
||||
method, self.url, protocol = reqline.split(" ")
|
||||
|
||||
if protocol == "HTTP/1.1":
|
||||
self.keep_alive = True
|
||||
elif protocol == "HTTP/1.0":
|
||||
self.keep_alive = False
|
||||
else:
|
||||
raise Exception # Raise a Bad Request on try-except
|
||||
|
||||
self.head_only = method.upper() == "HEAD"
|
||||
request_body = False
|
||||
headers = []
|
||||
|
||||
for name, value in (h.split(":", 1) for h in raw_headers):
|
||||
name, value = h = name.lower(), value.lstrip()
|
||||
|
||||
if name in ("content-length", "transfer-encoding"):
|
||||
request_body = True
|
||||
elif name == "connection":
|
||||
self.keep_alive = value.lower() == "keep-alive"
|
||||
|
||||
headers.append(h)
|
||||
except Exception:
|
||||
raise InvalidUsage("Bad Request")
|
||||
|
||||
headers_instance = Header(headers)
|
||||
self.upgrade_websocket = headers_instance.get("upgrade") == "websocket"
|
||||
|
||||
# Prepare a Request object
|
||||
request = self.protocol.request_class(
|
||||
url_bytes=self.url.encode(),
|
||||
headers=headers_instance,
|
||||
version=protocol[5:],
|
||||
method=method,
|
||||
transport=self.protocol.transport,
|
||||
app=self.protocol.app,
|
||||
)
|
||||
|
||||
# Prepare for request body
|
||||
self.request_bytes_left = self.request_bytes = 0
|
||||
if request_body:
|
||||
headers = request.headers
|
||||
expect = headers.get("expect")
|
||||
|
||||
if expect is not None:
|
||||
if expect.lower() == "100-continue":
|
||||
self.expecting_continue = True
|
||||
else:
|
||||
raise HeaderExpectationFailed(f"Unknown Expect: {expect}")
|
||||
|
||||
if headers.get("transfer-encoding") == "chunked":
|
||||
self.request_body = "chunked"
|
||||
pos -= 2 # One CRLF stays in buffer
|
||||
else:
|
||||
self.request_body = True
|
||||
self.request_bytes_left = self.request_bytes = int(
|
||||
headers["content-length"]
|
||||
)
|
||||
|
||||
# Remove header and its trailing CRLF
|
||||
del buf[: pos + 4]
|
||||
self.stage = Stage.HANDLER
|
||||
self.request, request.stream = request, self
|
||||
self.protocol.state["requests_count"] += 1
|
||||
|
||||
async def http1_response_header(self, data, end_stream):
|
||||
res = self.response
|
||||
|
||||
# Compatibility with simple response body
|
||||
if not data and getattr(res, "body", None):
|
||||
data, end_stream = res.body, True
|
||||
|
||||
size = len(data)
|
||||
headers = res.headers
|
||||
status = res.status
|
||||
|
||||
if not isinstance(status, int) or status < 200:
|
||||
raise RuntimeError(f"Invalid response status {status!r}")
|
||||
|
||||
if not has_message_body(status):
|
||||
# Header-only response status
|
||||
self.response_func = None
|
||||
if (
|
||||
data
|
||||
or not end_stream
|
||||
or "content-length" in headers
|
||||
or "transfer-encoding" in headers
|
||||
):
|
||||
data, size, end_stream = b"", 0, True
|
||||
headers.pop("content-length", None)
|
||||
headers.pop("transfer-encoding", None)
|
||||
logger.warning(
|
||||
f"Message body set in response on {self.request.path}. "
|
||||
f"A {status} response may only have headers, no body."
|
||||
)
|
||||
elif self.head_only and "content-length" in headers:
|
||||
self.response_func = None
|
||||
elif end_stream:
|
||||
# Non-streaming response (all in one block)
|
||||
headers["content-length"] = size
|
||||
self.response_func = None
|
||||
elif "content-length" in headers:
|
||||
# Streaming response with size known in advance
|
||||
self.response_bytes_left = int(headers["content-length"]) - size
|
||||
self.response_func = self.http1_response_normal
|
||||
else:
|
||||
# Length not known, use chunked encoding
|
||||
headers["transfer-encoding"] = "chunked"
|
||||
data = b"%x\r\n%b\r\n" % (size, data) if size else None
|
||||
self.response_func = self.http1_response_chunked
|
||||
|
||||
if self.head_only:
|
||||
# Head request: don't send body
|
||||
data = b""
|
||||
self.response_func = self.head_response_ignored
|
||||
|
||||
headers["connection"] = "keep-alive" if self.keep_alive else "close"
|
||||
ret = format_http1_response(status, res.processed_headers)
|
||||
if data:
|
||||
ret += data
|
||||
|
||||
# Send a 100-continue if expected and not Expectation Failed
|
||||
if self.expecting_continue:
|
||||
self.expecting_continue = False
|
||||
if status != 417:
|
||||
ret = HTTP_CONTINUE + ret
|
||||
|
||||
# Send response
|
||||
if self.protocol.access_log:
|
||||
self.log_response()
|
||||
|
||||
await self._send(ret)
|
||||
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
|
||||
|
||||
def head_response_ignored(self, data, end_stream):
|
||||
"""HEAD response: body data silently ignored."""
|
||||
if end_stream:
|
||||
self.response_func = None
|
||||
self.stage = Stage.IDLE
|
||||
|
||||
async def http1_response_chunked(self, data, end_stream):
|
||||
"""Format a part of response body in chunked encoding."""
|
||||
# Chunked encoding
|
||||
size = len(data)
|
||||
if end_stream:
|
||||
await self._send(
|
||||
b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
|
||||
if size
|
||||
else b"0\r\n\r\n"
|
||||
)
|
||||
self.response_func = None
|
||||
self.stage = Stage.IDLE
|
||||
elif size:
|
||||
await self._send(b"%x\r\n%b\r\n" % (size, data))
|
||||
|
||||
async def http1_response_normal(self, data: bytes, end_stream: bool):
|
||||
"""Format / keep track of non-chunked response."""
|
||||
bytes_left = self.response_bytes_left - len(data)
|
||||
if bytes_left <= 0:
|
||||
if bytes_left < 0:
|
||||
raise ServerError("Response was bigger than content-length")
|
||||
|
||||
await self._send(data)
|
||||
self.response_func = None
|
||||
self.stage = Stage.IDLE
|
||||
else:
|
||||
if end_stream:
|
||||
raise ServerError("Response was smaller than content-length")
|
||||
|
||||
await self._send(data)
|
||||
self.response_bytes_left = bytes_left
|
||||
|
||||
async def error_response(self, exception):
|
||||
# Disconnect after an error if in any other state than handler
|
||||
if self.stage is not Stage.HANDLER:
|
||||
self.keep_alive = False
|
||||
|
||||
# Request failure? Respond but then disconnect
|
||||
if self.stage is Stage.REQUEST:
|
||||
self.stage = Stage.HANDLER
|
||||
|
||||
# From request and handler states we can respond, otherwise be silent
|
||||
if self.stage is Stage.HANDLER:
|
||||
app = self.protocol.app
|
||||
|
||||
if self.request is None:
|
||||
self.create_empty_request()
|
||||
|
||||
await app.handle_exception(self.request, exception)
|
||||
|
||||
def create_empty_request(self):
|
||||
"""Current error handling code needs a request object that won't exist
|
||||
if an error occurred during before a request was received. Create a
|
||||
bogus response for error handling use."""
|
||||
# FIXME: Avoid this by refactoring error handling and response code
|
||||
self.request = self.protocol.request_class(
|
||||
url_bytes=self.url.encode() if self.url else b"*",
|
||||
headers=Header({}),
|
||||
version="1.1",
|
||||
method="NONE",
|
||||
transport=self.protocol.transport,
|
||||
app=self.protocol.app,
|
||||
)
|
||||
self.request.stream = self
|
||||
|
||||
def log_response(self):
|
||||
"""
|
||||
Helper method provided to enable the logging of responses in case if
|
||||
the :attr:`HttpProtocol.access_log` is enabled.
|
||||
|
||||
:param response: Response generated for the current request
|
||||
|
||||
:type response: :class:`sanic.response.HTTPResponse` or
|
||||
:class:`sanic.response.StreamingHTTPResponse`
|
||||
|
||||
:return: None
|
||||
"""
|
||||
req, res = self.request, self.response
|
||||
extra = {
|
||||
"status": getattr(res, "status", 0),
|
||||
"byte": getattr(self, "response_bytes_left", -1),
|
||||
"host": "UNKNOWN",
|
||||
"request": "nil",
|
||||
}
|
||||
if req is not None:
|
||||
if req.ip:
|
||||
extra["host"] = f"{req.ip}:{req.port}"
|
||||
extra["request"] = f"{req.method} {req.url}"
|
||||
access_logger.info("", extra=extra)
|
||||
|
||||
# Request methods
|
||||
|
||||
async def __aiter__(self):
|
||||
"""Async iterate over request body."""
|
||||
while self.request_body:
|
||||
data = await self.read()
|
||||
|
||||
if data:
|
||||
yield data
|
||||
|
||||
async def read(self):
|
||||
"""Read some bytes of request body."""
|
||||
# Send a 100-continue if needed
|
||||
if self.expecting_continue:
|
||||
self.expecting_continue = False
|
||||
await self._send(HTTP_CONTINUE)
|
||||
|
||||
# Receive request body chunk
|
||||
buf = self.recv_buffer
|
||||
if self.request_bytes_left == 0 and self.request_body == "chunked":
|
||||
# Process a chunk header: \r\n<size>[;<chunk extensions>]\r\n
|
||||
while True:
|
||||
pos = buf.find(b"\r\n", 3)
|
||||
|
||||
if pos != -1:
|
||||
break
|
||||
|
||||
if len(buf) > 64:
|
||||
self.keep_alive = False
|
||||
raise InvalidUsage("Bad chunked encoding")
|
||||
|
||||
await self._receive_more()
|
||||
|
||||
try:
|
||||
size = int(buf[2:pos].split(b";", 1)[0].decode(), 16)
|
||||
except Exception:
|
||||
self.keep_alive = False
|
||||
raise InvalidUsage("Bad chunked encoding")
|
||||
|
||||
del buf[: pos + 2]
|
||||
|
||||
if size <= 0:
|
||||
self.request_body = None
|
||||
|
||||
if size < 0:
|
||||
self.keep_alive = False
|
||||
raise InvalidUsage("Bad chunked encoding")
|
||||
|
||||
return None
|
||||
|
||||
self.request_bytes_left = size
|
||||
self.request_bytes += size
|
||||
|
||||
# Request size limit
|
||||
if self.request_bytes > self.request_max_size:
|
||||
self.keep_alive = False
|
||||
raise PayloadTooLarge("Request body exceeds the size limit")
|
||||
|
||||
# End of request body?
|
||||
if not self.request_bytes_left:
|
||||
self.request_body = None
|
||||
return
|
||||
|
||||
# At this point we are good to read/return up to request_bytes_left
|
||||
if not buf:
|
||||
await self._receive_more()
|
||||
|
||||
data = bytes(buf[: self.request_bytes_left])
|
||||
size = len(data)
|
||||
|
||||
del buf[:size]
|
||||
|
||||
self.request_bytes_left -= size
|
||||
|
||||
return data
|
||||
|
||||
# Response methods
|
||||
|
||||
def respond(self, response):
|
||||
"""Initiate new streaming response.
|
||||
|
||||
Nothing is sent until the first send() call on the returned object, and
|
||||
calling this function multiple times will just alter the response to be
|
||||
given."""
|
||||
if self.stage is not Stage.HANDLER:
|
||||
self.stage = Stage.FAILED
|
||||
raise RuntimeError("Response already started")
|
||||
|
||||
self.response, response.stream = response, self
|
||||
return response
|
||||
|
||||
@property
|
||||
def send(self):
|
||||
return self.response_func
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import email.utils
|
||||
|
||||
from collections import defaultdict, namedtuple
|
||||
|
@ -8,6 +7,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
|
|||
|
||||
from httptools import parse_url # type: ignore
|
||||
|
||||
from sanic.compat import CancelledErrors
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.headers import (
|
||||
parse_content_header,
|
||||
|
@ -16,6 +16,7 @@ from sanic.headers import (
|
|||
parse_xforwarded,
|
||||
)
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.response import BaseHTTPResponse, HTTPResponse
|
||||
|
||||
|
||||
try:
|
||||
|
@ -24,7 +25,6 @@ except ImportError:
|
|||
from json import loads as json_loads # type: ignore
|
||||
|
||||
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
|
||||
EXPECT_HEADER = "EXPECT"
|
||||
|
||||
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
|
||||
# > If the media type remains unknown, the recipient SHOULD treat it
|
||||
|
@ -45,35 +45,6 @@ class RequestParameters(dict):
|
|||
return super().get(name, default)
|
||||
|
||||
|
||||
class StreamBuffer:
|
||||
def __init__(self, buffer_size=100):
|
||||
self._queue = asyncio.Queue(buffer_size)
|
||||
|
||||
async def read(self):
|
||||
""" Stop reading when gets None """
|
||||
payload = await self._queue.get()
|
||||
self._queue.task_done()
|
||||
return payload
|
||||
|
||||
async def __aiter__(self):
|
||||
"""Support `async for data in request.stream`"""
|
||||
while True:
|
||||
data = await self.read()
|
||||
if not data:
|
||||
break
|
||||
yield data
|
||||
|
||||
async def put(self, payload):
|
||||
await self._queue.put(payload)
|
||||
|
||||
def is_full(self):
|
||||
return self._queue.full()
|
||||
|
||||
@property
|
||||
def buffer_size(self):
|
||||
return self._queue.maxsize
|
||||
|
||||
|
||||
class Request:
|
||||
"""Properties of an HTTP request such as URL, headers, etc."""
|
||||
|
||||
|
@ -92,6 +63,7 @@ class Request:
|
|||
"endpoint",
|
||||
"headers",
|
||||
"method",
|
||||
"name",
|
||||
"parsed_args",
|
||||
"parsed_not_grouped_args",
|
||||
"parsed_files",
|
||||
|
@ -99,6 +71,7 @@ class Request:
|
|||
"parsed_json",
|
||||
"parsed_forwarded",
|
||||
"raw_url",
|
||||
"request_middleware_started",
|
||||
"stream",
|
||||
"transport",
|
||||
"uri_template",
|
||||
|
@ -117,9 +90,10 @@ class Request:
|
|||
self.transport = transport
|
||||
|
||||
# Init but do not inhale
|
||||
self.body_init()
|
||||
self.body = b""
|
||||
self.conn_info = None
|
||||
self.ctx = SimpleNamespace()
|
||||
self.name = None
|
||||
self.parsed_forwarded = None
|
||||
self.parsed_json = None
|
||||
self.parsed_form = None
|
||||
|
@ -127,6 +101,7 @@ class Request:
|
|||
self.parsed_args = defaultdict(RequestParameters)
|
||||
self.parsed_not_grouped_args = defaultdict(list)
|
||||
self.uri_template = None
|
||||
self.request_middleware_started = False
|
||||
self._cookies = None
|
||||
self.stream = None
|
||||
self.endpoint = None
|
||||
|
@ -135,36 +110,43 @@ class Request:
|
|||
class_name = self.__class__.__name__
|
||||
return f"<{class_name}: {self.method} {self.path}>"
|
||||
|
||||
def body_init(self):
|
||||
""".. deprecated:: 20.3
|
||||
To be removed in 21.3"""
|
||||
self.body = []
|
||||
|
||||
def body_push(self, data):
|
||||
""".. deprecated:: 20.3
|
||||
To be removed in 21.3"""
|
||||
self.body.append(data)
|
||||
|
||||
def body_finish(self):
|
||||
""".. deprecated:: 20.3
|
||||
To be removed in 21.3"""
|
||||
self.body = b"".join(self.body)
|
||||
async def respond(
|
||||
self, response=None, *, status=200, headers=None, content_type=None
|
||||
):
|
||||
# This logic of determining which response to use is subject to change
|
||||
if response is None:
|
||||
response = self.stream.response or HTTPResponse(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
)
|
||||
# Connect the response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
response = self.stream.respond(response)
|
||||
# Run response middleware
|
||||
try:
|
||||
response = await self.app._run_response_middleware(
|
||||
self, response, request_name=self.name
|
||||
)
|
||||
except CancelledErrors:
|
||||
raise
|
||||
except Exception:
|
||||
error_logger.exception(
|
||||
"Exception occurred in one of response middleware handlers"
|
||||
)
|
||||
return response
|
||||
|
||||
async def receive_body(self):
|
||||
"""Receive request.body, if not already received.
|
||||
|
||||
Streaming handlers may call this to receive the full body.
|
||||
Streaming handlers may call this to receive the full body. Sanic calls
|
||||
this function before running any handlers of non-streaming routes.
|
||||
|
||||
This is added as a compatibility shim in Sanic 20.3 because future
|
||||
versions of Sanic will make all requests streaming and will use this
|
||||
function instead of the non-async body_init/push/finish functions.
|
||||
|
||||
Please make an issue if your code depends on the old functionality and
|
||||
cannot be upgraded to the new API.
|
||||
Custom request classes can override this for custom handling of both
|
||||
streaming and non-streaming routes.
|
||||
"""
|
||||
if not self.stream:
|
||||
return
|
||||
self.body = b"".join([data async for data in self.stream])
|
||||
if not self.body:
|
||||
self.body = b"".join([data async for data in self.stream])
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
|
|
|
@ -2,10 +2,10 @@ from functools import partial
|
|||
from mimetypes import guess_type
|
||||
from os import path
|
||||
from urllib.parse import quote_plus
|
||||
from warnings import warn
|
||||
|
||||
from sanic.compat import Header, open_async
|
||||
from sanic.cookies import CookieJar
|
||||
from sanic.headers import format_http1, format_http1_response
|
||||
from sanic.helpers import has_message_body, remove_entity_headers
|
||||
|
||||
|
||||
|
@ -28,50 +28,51 @@ class BaseHTTPResponse:
|
|||
return b""
|
||||
return data.encode() if hasattr(data, "encode") else data
|
||||
|
||||
def _parse_headers(self):
|
||||
return format_http1(self.headers.items())
|
||||
|
||||
@property
|
||||
def cookies(self):
|
||||
if self._cookies is None:
|
||||
self._cookies = CookieJar(self.headers)
|
||||
return self._cookies
|
||||
|
||||
def get_headers(
|
||||
self,
|
||||
version="1.1",
|
||||
keep_alive=False,
|
||||
keep_alive_timeout=None,
|
||||
body=b"",
|
||||
):
|
||||
""".. deprecated:: 20.3:
|
||||
This function is not public API and will be removed in 21.3."""
|
||||
@property
|
||||
def processed_headers(self):
|
||||
"""Obtain a list of header tuples encoded in bytes for sending.
|
||||
|
||||
# self.headers get priority over content_type
|
||||
if self.content_type and "Content-Type" not in self.headers:
|
||||
self.headers["Content-Type"] = self.content_type
|
||||
|
||||
if keep_alive:
|
||||
self.headers["Connection"] = "keep-alive"
|
||||
if keep_alive_timeout is not None:
|
||||
self.headers["Keep-Alive"] = keep_alive_timeout
|
||||
else:
|
||||
self.headers["Connection"] = "close"
|
||||
|
||||
if self.status in (304, 412):
|
||||
Add and remove headers based on status and content_type.
|
||||
"""
|
||||
# TODO: Make a blacklist set of header names and then filter with that
|
||||
if self.status in (304, 412): # Not Modified, Precondition Failed
|
||||
self.headers = remove_entity_headers(self.headers)
|
||||
if has_message_body(self.status):
|
||||
self.headers.setdefault("content-type", self.content_type)
|
||||
# Encode headers into bytes
|
||||
return (
|
||||
(name.encode("ascii"), f"{value}".encode(errors="surrogateescape"))
|
||||
for name, value in self.headers.items()
|
||||
)
|
||||
|
||||
return format_http1_response(self.status, self.headers.items(), body)
|
||||
async def send(self, data=None, end_stream=None):
|
||||
"""Send any pending response headers and the given data as body.
|
||||
:param data: str or bytes to be written
|
||||
:end_stream: whether to close the stream after this block
|
||||
"""
|
||||
if data is None and end_stream is None:
|
||||
end_stream = True
|
||||
if end_stream and not data and self.stream.send is None:
|
||||
return
|
||||
data = data.encode() if hasattr(data, "encode") else data or b""
|
||||
await self.stream.send(data, end_stream=end_stream)
|
||||
|
||||
|
||||
class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
"""Old style streaming response. Use `request.respond()` instead of this in
|
||||
new code to avoid the callback."""
|
||||
|
||||
__slots__ = (
|
||||
"protocol",
|
||||
"streaming_fn",
|
||||
"status",
|
||||
"content_type",
|
||||
"headers",
|
||||
"chunked",
|
||||
"_cookies",
|
||||
)
|
||||
|
||||
|
@ -81,63 +82,34 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
|||
status=200,
|
||||
headers=None,
|
||||
content_type="text/plain; charset=utf-8",
|
||||
chunked=True,
|
||||
chunked="deprecated",
|
||||
):
|
||||
if chunked != "deprecated":
|
||||
warn(
|
||||
"The chunked argument has been deprecated and will be "
|
||||
"removed in v21.6"
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.content_type = content_type
|
||||
self.streaming_fn = streaming_fn
|
||||
self.status = status
|
||||
self.headers = Header(headers or {})
|
||||
self.chunked = chunked
|
||||
self._cookies = None
|
||||
self.protocol = None
|
||||
|
||||
async def write(self, data):
|
||||
"""Writes a chunk of data to the streaming response.
|
||||
|
||||
:param data: str or bytes-ish data to be written.
|
||||
"""
|
||||
data = self._encode_body(data)
|
||||
await super().send(self._encode_body(data))
|
||||
|
||||
# `chunked` will always be False in ASGI-mode, even if the underlying
|
||||
# ASGI Transport implements Chunked transport. That does it itself.
|
||||
if self.chunked:
|
||||
await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
|
||||
else:
|
||||
await self.protocol.push_data(data)
|
||||
await self.protocol.drain()
|
||||
|
||||
async def stream(
|
||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None
|
||||
):
|
||||
"""Streams headers, runs the `streaming_fn` callback that writes
|
||||
content to the response body, then finalizes the response body.
|
||||
"""
|
||||
if version != "1.1":
|
||||
self.chunked = False
|
||||
if not getattr(self, "asgi", False):
|
||||
headers = self.get_headers(
|
||||
version,
|
||||
keep_alive=keep_alive,
|
||||
keep_alive_timeout=keep_alive_timeout,
|
||||
)
|
||||
await self.protocol.push_data(headers)
|
||||
await self.protocol.drain()
|
||||
await self.streaming_fn(self)
|
||||
if self.chunked:
|
||||
await self.protocol.push_data(b"0\r\n\r\n")
|
||||
# no need to await drain here after this write, because it is the
|
||||
# very last thing we write and nothing needs to wait for it.
|
||||
|
||||
def get_headers(
|
||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None
|
||||
):
|
||||
if self.chunked and version == "1.1":
|
||||
self.headers["Transfer-Encoding"] = "chunked"
|
||||
self.headers.pop("Content-Length", None)
|
||||
|
||||
return super().get_headers(version, keep_alive, keep_alive_timeout)
|
||||
async def send(self, *args, **kwargs):
|
||||
if self.streaming_fn is not None:
|
||||
await self.streaming_fn(self)
|
||||
self.streaming_fn = None
|
||||
await super().send(*args, **kwargs)
|
||||
|
||||
|
||||
class HTTPResponse(BaseHTTPResponse):
|
||||
|
@ -158,22 +130,6 @@ class HTTPResponse(BaseHTTPResponse):
|
|||
self.headers = Header(headers or {})
|
||||
self._cookies = None
|
||||
|
||||
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||
body = b""
|
||||
if has_message_body(self.status):
|
||||
body = self.body
|
||||
self.headers["Content-Length"] = self.headers.get(
|
||||
"Content-Length", len(self.body)
|
||||
)
|
||||
|
||||
return self.get_headers(version, keep_alive, keep_alive_timeout, body)
|
||||
|
||||
@property
|
||||
def cookies(self):
|
||||
if self._cookies is None:
|
||||
self._cookies = CookieJar(self.headers)
|
||||
return self._cookies
|
||||
|
||||
|
||||
def empty(status=204, headers=None):
|
||||
"""
|
||||
|
@ -319,7 +275,7 @@ async def file_stream(
|
|||
mime_type=None,
|
||||
headers=None,
|
||||
filename=None,
|
||||
chunked=True,
|
||||
chunked="deprecated",
|
||||
_range=None,
|
||||
):
|
||||
"""Return a streaming response object with file data.
|
||||
|
@ -329,9 +285,15 @@ async def file_stream(
|
|||
:param mime_type: Specific mime_type.
|
||||
:param headers: Custom Headers.
|
||||
:param filename: Override filename.
|
||||
:param chunked: Enable or disable chunked transfer-encoding
|
||||
:param chunked: Deprecated
|
||||
:param _range:
|
||||
"""
|
||||
if chunked != "deprecated":
|
||||
warn(
|
||||
"The chunked argument has been deprecated and will be "
|
||||
"removed in v21.6"
|
||||
)
|
||||
|
||||
headers = headers or {}
|
||||
if filename:
|
||||
headers.setdefault(
|
||||
|
@ -370,7 +332,6 @@ async def file_stream(
|
|||
status=status,
|
||||
headers=headers,
|
||||
content_type=mime_type,
|
||||
chunked=chunked,
|
||||
)
|
||||
|
||||
|
||||
|
@ -379,7 +340,7 @@ def stream(
|
|||
status=200,
|
||||
headers=None,
|
||||
content_type="text/plain; charset=utf-8",
|
||||
chunked=True,
|
||||
chunked="deprecated",
|
||||
):
|
||||
"""Accepts an coroutine `streaming_fn` which can be used to
|
||||
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
|
||||
|
@ -398,14 +359,19 @@ def stream(
|
|||
writes content to that response.
|
||||
:param mime_type: Specific mime_type.
|
||||
:param headers: Custom Headers.
|
||||
:param chunked: Enable or disable chunked transfer-encoding
|
||||
:param chunked: Deprecated
|
||||
"""
|
||||
if chunked != "deprecated":
|
||||
warn(
|
||||
"The chunked argument has been deprecated and will be "
|
||||
"removed in v21.6"
|
||||
)
|
||||
|
||||
return StreamingHTTPResponse(
|
||||
streaming_fn,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
status=status,
|
||||
chunked=chunked,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ Route = namedtuple(
|
|||
"name",
|
||||
"uri",
|
||||
"endpoint",
|
||||
"ignore_body",
|
||||
],
|
||||
)
|
||||
Parameter = namedtuple("Parameter", ["name", "cast"])
|
||||
|
@ -135,6 +136,7 @@ class Router:
|
|||
handler,
|
||||
host=None,
|
||||
strict_slashes=False,
|
||||
ignore_body=False,
|
||||
version=None,
|
||||
name=None,
|
||||
):
|
||||
|
@ -146,6 +148,7 @@ class Router:
|
|||
:param handler: request handler function.
|
||||
When executed, it should provide a response object.
|
||||
:param strict_slashes: strict to trailing slash
|
||||
:param ignore_body: Handler should not read the body, if any
|
||||
:param version: current version of the route or blueprint. See
|
||||
docs for further details.
|
||||
:return: Nothing
|
||||
|
@ -155,7 +158,9 @@ class Router:
|
|||
version = re.escape(str(version).strip("/").lstrip("v"))
|
||||
uri = "/".join([f"/v{version}", uri.lstrip("/")])
|
||||
# add regular version
|
||||
routes.append(self._add(uri, methods, handler, host, name))
|
||||
routes.append(
|
||||
self._add(uri, methods, handler, host, name, ignore_body)
|
||||
)
|
||||
|
||||
if strict_slashes:
|
||||
return routes
|
||||
|
@ -187,14 +192,20 @@ class Router:
|
|||
)
|
||||
# add version with trailing slash
|
||||
if slash_is_missing:
|
||||
routes.append(self._add(uri + "/", methods, handler, host, name))
|
||||
routes.append(
|
||||
self._add(uri + "/", methods, handler, host, name, ignore_body)
|
||||
)
|
||||
# add version without trailing slash
|
||||
elif without_slash_is_missing:
|
||||
routes.append(self._add(uri[:-1], methods, handler, host, name))
|
||||
routes.append(
|
||||
self._add(uri[:-1], methods, handler, host, name, ignore_body)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
||||
def _add(self, uri, methods, handler, host=None, name=None):
|
||||
def _add(
|
||||
self, uri, methods, handler, host=None, name=None, ignore_body=False
|
||||
):
|
||||
"""Add a handler to the route list
|
||||
|
||||
:param uri: path to match
|
||||
|
@ -326,6 +337,7 @@ class Router:
|
|||
name=handler_name,
|
||||
uri=uri,
|
||||
endpoint=endpoint,
|
||||
ignore_body=ignore_body,
|
||||
)
|
||||
|
||||
self.routes_all[uri] = route
|
||||
|
@ -465,7 +477,15 @@ class Router:
|
|||
if hasattr(route_handler, "handlers"):
|
||||
route_handler = route_handler.handlers[method]
|
||||
|
||||
return route_handler, [], kwargs, route.uri, route.name, route.endpoint
|
||||
return (
|
||||
route_handler,
|
||||
[],
|
||||
kwargs,
|
||||
route.uri,
|
||||
route.name,
|
||||
route.endpoint,
|
||||
route.ignore_body,
|
||||
)
|
||||
|
||||
def is_stream_handler(self, request):
|
||||
"""Handler for request is stream or not.
|
||||
|
|
668
sanic/server.py
668
sanic/server.py
|
@ -5,33 +5,22 @@ import secrets
|
|||
import socket
|
||||
import stat
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from collections import deque
|
||||
from asyncio import CancelledError
|
||||
from functools import partial
|
||||
from inspect import isawaitable
|
||||
from ipaddress import ip_address
|
||||
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
||||
from signal import signal as signal_func
|
||||
from time import time
|
||||
from time import monotonic as current_time
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from httptools import HttpRequestParser # type: ignore
|
||||
from httptools.parser.errors import HttpParserError # type: ignore
|
||||
|
||||
from sanic.compat import Header, ctrlc_workaround_for_windows
|
||||
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
|
||||
from sanic.config import Config
|
||||
from sanic.exceptions import (
|
||||
HeaderExpectationFailed,
|
||||
InvalidUsage,
|
||||
PayloadTooLarge,
|
||||
RequestTimeout,
|
||||
ServerError,
|
||||
ServiceUnavailable,
|
||||
)
|
||||
from sanic.log import access_logger, logger
|
||||
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
|
||||
from sanic.response import HTTPResponse
|
||||
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
||||
from sanic.http import Http, Stage
|
||||
from sanic.log import logger
|
||||
from sanic.request import Request
|
||||
|
||||
|
||||
try:
|
||||
|
@ -42,8 +31,6 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
OS_IS_WINDOWS = os.name == "nt"
|
||||
|
||||
|
||||
class Signal:
|
||||
stopped = False
|
||||
|
@ -99,10 +86,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"signal",
|
||||
"conn_info",
|
||||
# request params
|
||||
"parser",
|
||||
"request",
|
||||
"url",
|
||||
"headers",
|
||||
# request config
|
||||
"request_handler",
|
||||
"request_timeout",
|
||||
|
@ -111,26 +95,21 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"request_max_size",
|
||||
"request_buffer_queue_size",
|
||||
"request_class",
|
||||
"is_request_stream",
|
||||
"error_handler",
|
||||
# enable or disable access log purpose
|
||||
"access_log",
|
||||
# connection management
|
||||
"_total_request_size",
|
||||
"_request_timeout_handler",
|
||||
"_response_timeout_handler",
|
||||
"_keep_alive_timeout_handler",
|
||||
"_last_request_time",
|
||||
"_last_response_time",
|
||||
"_is_stream_handler",
|
||||
"_not_paused",
|
||||
"_request_handler_task",
|
||||
"_request_stream_task",
|
||||
"_keep_alive",
|
||||
"_header_fragment",
|
||||
"state",
|
||||
"url",
|
||||
"_handler_task",
|
||||
"_can_write",
|
||||
"_data_received",
|
||||
"_time",
|
||||
"_task",
|
||||
"_http",
|
||||
"_exception",
|
||||
"recv_buffer",
|
||||
"_unix",
|
||||
"_body_chunks",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -148,12 +127,10 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self.loop = loop
|
||||
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
|
||||
self.app = app
|
||||
self.url = None
|
||||
self.transport = None
|
||||
self.conn_info = None
|
||||
self.request = None
|
||||
self.parser = None
|
||||
self.url = None
|
||||
self.headers = None
|
||||
self.signal = signal
|
||||
self.access_log = self.app.config.ACCESS_LOG
|
||||
self.connections = connections if connections is not None else set()
|
||||
|
@ -167,511 +144,99 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
|
||||
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
|
||||
self.request_class = self.app.request_class or Request
|
||||
self.is_request_stream = self.app.is_request_stream
|
||||
self._is_stream_handler = False
|
||||
self._not_paused = asyncio.Event(loop=deprecated_loop)
|
||||
self._total_request_size = 0
|
||||
self._request_timeout_handler = None
|
||||
self._response_timeout_handler = None
|
||||
self._keep_alive_timeout_handler = None
|
||||
self._last_request_time = None
|
||||
self._last_response_time = None
|
||||
self._request_handler_task = None
|
||||
self._request_stream_task = None
|
||||
self._keep_alive = self.app.config.KEEP_ALIVE
|
||||
self._header_fragment = b""
|
||||
self.state = state if state else {}
|
||||
if "requests_count" not in self.state:
|
||||
self.state["requests_count"] = 0
|
||||
self._data_received = asyncio.Event(loop=deprecated_loop)
|
||||
self._can_write = asyncio.Event(loop=deprecated_loop)
|
||||
self._can_write.set()
|
||||
self._exception = None
|
||||
self._unix = unix
|
||||
self._not_paused.set()
|
||||
self._body_chunks = deque()
|
||||
|
||||
@property
|
||||
def keep_alive(self):
|
||||
def _setup_connection(self):
|
||||
self._http = Http(self)
|
||||
self._time = current_time()
|
||||
self.check_timeouts()
|
||||
|
||||
async def connection_task(self):
|
||||
"""Run a HTTP connection.
|
||||
|
||||
Timeouts and some additional error handling occur here, while most of
|
||||
everything else happens in class Http or in code called from there.
|
||||
"""
|
||||
Check if the connection needs to be kept alive based on the params
|
||||
attached to the `_keep_alive` attribute, :attr:`Signal.stopped`
|
||||
and :func:`HttpProtocol.parser.should_keep_alive`
|
||||
|
||||
:return: ``True`` if connection is to be kept alive ``False`` else
|
||||
"""
|
||||
return (
|
||||
self._keep_alive
|
||||
and not self.signal.stopped
|
||||
and self.parser.should_keep_alive()
|
||||
)
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Connection
|
||||
# -------------------------------------------- #
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.connections.add(self)
|
||||
self._request_timeout_handler = self.loop.call_later(
|
||||
self.request_timeout, self.request_timeout_callback
|
||||
)
|
||||
self.transport = transport
|
||||
self.conn_info = ConnInfo(transport, unix=self._unix)
|
||||
self._last_request_time = time()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.connections.discard(self)
|
||||
if self._request_handler_task:
|
||||
self._request_handler_task.cancel()
|
||||
if self._request_stream_task:
|
||||
self._request_stream_task.cancel()
|
||||
if self._request_timeout_handler:
|
||||
self._request_timeout_handler.cancel()
|
||||
if self._response_timeout_handler:
|
||||
self._response_timeout_handler.cancel()
|
||||
if self._keep_alive_timeout_handler:
|
||||
self._keep_alive_timeout_handler.cancel()
|
||||
|
||||
def pause_writing(self):
|
||||
self._not_paused.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self._not_paused.set()
|
||||
|
||||
def request_timeout_callback(self):
|
||||
# See the docstring in the RequestTimeout exception, to see
|
||||
# exactly what this timeout is checking for.
|
||||
# Check if elapsed time since request initiated exceeds our
|
||||
# configured maximum request timeout value
|
||||
time_elapsed = time() - self._last_request_time
|
||||
if time_elapsed < self.request_timeout:
|
||||
time_left = self.request_timeout - time_elapsed
|
||||
self._request_timeout_handler = self.loop.call_later(
|
||||
time_left, self.request_timeout_callback
|
||||
)
|
||||
else:
|
||||
if self._request_stream_task:
|
||||
self._request_stream_task.cancel()
|
||||
if self._request_handler_task:
|
||||
self._request_handler_task.cancel()
|
||||
self.write_error(RequestTimeout("Request Timeout"))
|
||||
|
||||
def response_timeout_callback(self):
|
||||
# Check if elapsed time since response was initiated exceeds our
|
||||
# configured maximum request timeout value
|
||||
time_elapsed = time() - self._last_request_time
|
||||
if time_elapsed < self.response_timeout:
|
||||
time_left = self.response_timeout - time_elapsed
|
||||
self._response_timeout_handler = self.loop.call_later(
|
||||
time_left, self.response_timeout_callback
|
||||
)
|
||||
else:
|
||||
if self._request_stream_task:
|
||||
self._request_stream_task.cancel()
|
||||
if self._request_handler_task:
|
||||
self._request_handler_task.cancel()
|
||||
self.write_error(ServiceUnavailable("Response Timeout"))
|
||||
|
||||
def keep_alive_timeout_callback(self):
|
||||
"""
|
||||
Check if elapsed time since last response exceeds our configured
|
||||
maximum keep alive timeout value and if so, close the transport
|
||||
pipe and let the response writer handle the error.
|
||||
|
||||
:return: None
|
||||
"""
|
||||
time_elapsed = time() - self._last_response_time
|
||||
if time_elapsed < self.keep_alive_timeout:
|
||||
time_left = self.keep_alive_timeout - time_elapsed
|
||||
self._keep_alive_timeout_handler = self.loop.call_later(
|
||||
time_left, self.keep_alive_timeout_callback
|
||||
)
|
||||
else:
|
||||
logger.debug("KeepAlive Timeout. Closing connection.")
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Parsing
|
||||
# -------------------------------------------- #
|
||||
|
||||
def data_received(self, data):
|
||||
# Check for the request itself getting too large and exceeding
|
||||
# memory limits
|
||||
self._total_request_size += len(data)
|
||||
if self._total_request_size > self.request_max_size:
|
||||
self.write_error(PayloadTooLarge("Payload Too Large"))
|
||||
|
||||
# Create parser if this is the first time we're receiving data
|
||||
if self.parser is None:
|
||||
assert self.request is None
|
||||
self.headers = []
|
||||
self.parser = HttpRequestParser(self)
|
||||
|
||||
# requests count
|
||||
self.state["requests_count"] = self.state["requests_count"] + 1
|
||||
|
||||
# Parse request chunk or close connection
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
except HttpParserError:
|
||||
message = "Bad Request"
|
||||
if self.app.debug:
|
||||
message += "\n" + traceback.format_exc()
|
||||
self.write_error(InvalidUsage(message))
|
||||
|
||||
def on_url(self, url):
|
||||
if not self.url:
|
||||
self.url = url
|
||||
else:
|
||||
self.url += url
|
||||
|
||||
def on_header(self, name, value):
|
||||
self._header_fragment += name
|
||||
|
||||
if value is not None:
|
||||
if (
|
||||
self._header_fragment == b"Content-Length"
|
||||
and int(value) > self.request_max_size
|
||||
):
|
||||
self.write_error(PayloadTooLarge("Payload Too Large"))
|
||||
try:
|
||||
value = value.decode()
|
||||
except UnicodeDecodeError:
|
||||
value = value.decode("latin_1")
|
||||
self.headers.append(
|
||||
(self._header_fragment.decode().casefold(), value)
|
||||
)
|
||||
|
||||
self._header_fragment = b""
|
||||
|
||||
def on_headers_complete(self):
|
||||
self.request = self.request_class(
|
||||
url_bytes=self.url,
|
||||
headers=Header(self.headers),
|
||||
version=self.parser.get_http_version(),
|
||||
method=self.parser.get_method().decode(),
|
||||
transport=self.transport,
|
||||
app=self.app,
|
||||
)
|
||||
self.request.conn_info = self.conn_info
|
||||
# Remove any existing KeepAlive handler here,
|
||||
# It will be recreated if required on the new request.
|
||||
if self._keep_alive_timeout_handler:
|
||||
self._keep_alive_timeout_handler.cancel()
|
||||
self._keep_alive_timeout_handler = None
|
||||
|
||||
if self.request.headers.get(EXPECT_HEADER):
|
||||
self.expect_handler()
|
||||
|
||||
if self.is_request_stream:
|
||||
self._is_stream_handler = self.app.router.is_stream_handler(
|
||||
self.request
|
||||
)
|
||||
if self._is_stream_handler:
|
||||
self.request.stream = StreamBuffer(
|
||||
self.request_buffer_queue_size
|
||||
)
|
||||
self.execute_request_handler()
|
||||
|
||||
def expect_handler(self):
|
||||
"""
|
||||
Handler for Expect Header.
|
||||
"""
|
||||
expect = self.request.headers.get(EXPECT_HEADER)
|
||||
if self.request.version == "1.1":
|
||||
if expect.lower() == "100-continue":
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
else:
|
||||
self.write_error(
|
||||
HeaderExpectationFailed(f"Unknown Expect: {expect}")
|
||||
)
|
||||
|
||||
def on_body(self, body):
|
||||
if self.is_request_stream and self._is_stream_handler:
|
||||
# body chunks can be put into asyncio.Queue out of order if
|
||||
# multiple tasks put concurrently and the queue is full in python
|
||||
# 3.7. so we should not create more than one task putting into the
|
||||
# queue simultaneously.
|
||||
self._body_chunks.append(body)
|
||||
if (
|
||||
not self._request_stream_task
|
||||
or self._request_stream_task.done()
|
||||
):
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.stream_append()
|
||||
)
|
||||
else:
|
||||
self.request.body_push(body)
|
||||
|
||||
async def body_append(self, body):
|
||||
if (
|
||||
self.request is None
|
||||
or self._request_stream_task is None
|
||||
or self._request_stream_task.cancelled()
|
||||
):
|
||||
return
|
||||
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
|
||||
async def stream_append(self):
|
||||
while self._body_chunks:
|
||||
body = self._body_chunks.popleft()
|
||||
if self.request:
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
|
||||
def on_message_complete(self):
|
||||
# Entire request (headers and whole body) is received.
|
||||
# We can cancel and remove the request timeout handler now.
|
||||
if self._request_timeout_handler:
|
||||
self._request_timeout_handler.cancel()
|
||||
self._request_timeout_handler = None
|
||||
if self.is_request_stream and self._is_stream_handler:
|
||||
self._body_chunks.append(None)
|
||||
if (
|
||||
not self._request_stream_task
|
||||
or self._request_stream_task.done()
|
||||
):
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.stream_append()
|
||||
)
|
||||
return
|
||||
self.request.body_finish()
|
||||
self.execute_request_handler()
|
||||
|
||||
def execute_request_handler(self):
|
||||
"""
|
||||
Invoke the request handler defined by the
|
||||
:func:`sanic.app.Sanic.handle_request` method
|
||||
|
||||
:return: None
|
||||
"""
|
||||
self._response_timeout_handler = self.loop.call_later(
|
||||
self.response_timeout, self.response_timeout_callback
|
||||
)
|
||||
self._last_request_time = time()
|
||||
self._request_handler_task = self.loop.create_task(
|
||||
self.request_handler(
|
||||
self.request, self.write_response, self.stream_response
|
||||
)
|
||||
)
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Responding
|
||||
# -------------------------------------------- #
|
||||
def log_response(self, response):
|
||||
"""
|
||||
Helper method provided to enable the logging of responses in case if
|
||||
the :attr:`HttpProtocol.access_log` is enabled.
|
||||
|
||||
:param response: Response generated for the current request
|
||||
|
||||
:type response: :class:`sanic.response.HTTPResponse` or
|
||||
:class:`sanic.response.StreamingHTTPResponse`
|
||||
|
||||
:return: None
|
||||
"""
|
||||
if self.access_log:
|
||||
extra = {"status": getattr(response, "status", 0)}
|
||||
|
||||
if isinstance(response, HTTPResponse):
|
||||
extra["byte"] = len(response.body)
|
||||
else:
|
||||
extra["byte"] = -1
|
||||
|
||||
extra["host"] = "UNKNOWN"
|
||||
if self.request is not None:
|
||||
if self.request.ip:
|
||||
extra["host"] = f"{self.request.ip}:{self.request.port}"
|
||||
|
||||
extra["request"] = f"{self.request.method} {self.request.url}"
|
||||
else:
|
||||
extra["request"] = "nil"
|
||||
|
||||
access_logger.info("", extra=extra)
|
||||
|
||||
def write_response(self, response):
|
||||
"""
|
||||
Writes response content synchronously to the transport.
|
||||
"""
|
||||
if self._response_timeout_handler:
|
||||
self._response_timeout_handler.cancel()
|
||||
self._response_timeout_handler = None
|
||||
try:
|
||||
keep_alive = self.keep_alive
|
||||
self.transport.write(
|
||||
response.output(
|
||||
self.request.version, keep_alive, self.keep_alive_timeout
|
||||
)
|
||||
)
|
||||
self.log_response(response)
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
"Invalid response object for url %s, "
|
||||
"Expected Type: HTTPResponse, Actual Type: %s",
|
||||
self.url,
|
||||
type(response),
|
||||
)
|
||||
self.write_error(ServerError("Invalid response type"))
|
||||
except RuntimeError:
|
||||
if self.app.debug:
|
||||
logger.error(
|
||||
"Connection lost before response written @ %s",
|
||||
self.request.ip,
|
||||
)
|
||||
keep_alive = False
|
||||
except Exception as e:
|
||||
self.bail_out(f"Writing response failed, connection closed {e!r}")
|
||||
self._setup_connection()
|
||||
await self._http.http1()
|
||||
except CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("protocol.connection_task uncaught")
|
||||
finally:
|
||||
if not keep_alive:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
else:
|
||||
self._keep_alive_timeout_handler = self.loop.call_later(
|
||||
self.keep_alive_timeout, self.keep_alive_timeout_callback
|
||||
if self.app.debug and self._http:
|
||||
ip = self.transport.get_extra_info("peername")
|
||||
logger.error(
|
||||
"Connection lost before response written"
|
||||
f" @ {ip} {self._http.request}"
|
||||
)
|
||||
self._last_response_time = time()
|
||||
self.cleanup()
|
||||
self._http = None
|
||||
self._task = None
|
||||
try:
|
||||
self.close()
|
||||
except BaseException:
|
||||
logger.exception("Closing failed")
|
||||
|
||||
async def drain(self):
|
||||
await self._not_paused.wait()
|
||||
async def receive_more(self):
|
||||
"""Wait until more data is received into self._buffer."""
|
||||
self.transport.resume_reading()
|
||||
self._data_received.clear()
|
||||
await self._data_received.wait()
|
||||
|
||||
async def push_data(self, data):
|
||||
def check_timeouts(self):
|
||||
"""Runs itself periodically to enforce any expired timeouts."""
|
||||
try:
|
||||
if not self._task:
|
||||
return
|
||||
duration = current_time() - self._time
|
||||
stage = self._http.stage
|
||||
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
|
||||
logger.debug("KeepAlive Timeout. Closing connection.")
|
||||
elif stage is Stage.REQUEST and duration > self.request_timeout:
|
||||
self._http.exception = RequestTimeout("Request Timeout")
|
||||
elif (
|
||||
stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
|
||||
and duration > self.response_timeout
|
||||
):
|
||||
self._http.exception = ServiceUnavailable("Response Timeout")
|
||||
else:
|
||||
interval = (
|
||||
min(
|
||||
self.keep_alive_timeout,
|
||||
self.request_timeout,
|
||||
self.response_timeout,
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
self.loop.call_later(max(0.1, interval), self.check_timeouts)
|
||||
return
|
||||
self._task.cancel()
|
||||
except Exception:
|
||||
logger.exception("protocol.check_timeouts")
|
||||
|
||||
async def send(self, data):
|
||||
"""Writes data with backpressure control."""
|
||||
await self._can_write.wait()
|
||||
if self.transport.is_closing():
|
||||
raise CancelledError
|
||||
self.transport.write(data)
|
||||
|
||||
async def stream_response(self, response):
|
||||
"""
|
||||
Streams a response to the client asynchronously. Attaches
|
||||
the transport to the response so the response consumer can
|
||||
write to the response as needed.
|
||||
"""
|
||||
if self._response_timeout_handler:
|
||||
self._response_timeout_handler.cancel()
|
||||
self._response_timeout_handler = None
|
||||
|
||||
try:
|
||||
keep_alive = self.keep_alive
|
||||
response.protocol = self
|
||||
await response.stream(
|
||||
self.request.version, keep_alive, self.keep_alive_timeout
|
||||
)
|
||||
self.log_response(response)
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
"Invalid response object for url %s, "
|
||||
"Expected Type: HTTPResponse, Actual Type: %s",
|
||||
self.url,
|
||||
type(response),
|
||||
)
|
||||
self.write_error(ServerError("Invalid response type"))
|
||||
except RuntimeError:
|
||||
if self.app.debug:
|
||||
logger.error(
|
||||
"Connection lost before response written @ %s",
|
||||
self.request.ip,
|
||||
)
|
||||
keep_alive = False
|
||||
except Exception as e:
|
||||
self.bail_out(f"Writing response failed, connection closed {e!r}")
|
||||
finally:
|
||||
if not keep_alive:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
else:
|
||||
self._keep_alive_timeout_handler = self.loop.call_later(
|
||||
self.keep_alive_timeout, self.keep_alive_timeout_callback
|
||||
)
|
||||
self._last_response_time = time()
|
||||
self.cleanup()
|
||||
|
||||
def write_error(self, exception):
|
||||
# An error _is_ a response.
|
||||
# Don't throw a response timeout, when a response _is_ given.
|
||||
if self._response_timeout_handler:
|
||||
self._response_timeout_handler.cancel()
|
||||
self._response_timeout_handler = None
|
||||
response = None
|
||||
try:
|
||||
response = self.error_handler.response(self.request, exception)
|
||||
version = self.request.version if self.request else "1.1"
|
||||
self.transport.write(response.output(version))
|
||||
except RuntimeError:
|
||||
if self.app.debug:
|
||||
logger.error(
|
||||
"Connection lost before error written @ %s",
|
||||
self.request.ip if self.request else "Unknown",
|
||||
)
|
||||
except Exception as e:
|
||||
self.bail_out(
|
||||
f"Writing error failed, connection closed {e!r}",
|
||||
from_error=True,
|
||||
)
|
||||
finally:
|
||||
if self.parser and (
|
||||
self.keep_alive or getattr(response, "status", 0) == 408
|
||||
):
|
||||
self.log_response(response)
|
||||
try:
|
||||
self.transport.close()
|
||||
except AttributeError:
|
||||
logger.debug("Connection lost before server could close it.")
|
||||
|
||||
def bail_out(self, message, from_error=False):
|
||||
"""
|
||||
In case if the transport pipes are closed and the sanic app encounters
|
||||
an error while writing data to the transport pipe, we log the error
|
||||
with proper details.
|
||||
|
||||
:param message: Error message to display
|
||||
:param from_error: If the bail out was invoked while handling an
|
||||
exception scenario.
|
||||
|
||||
:type message: str
|
||||
:type from_error: bool
|
||||
|
||||
:return: None
|
||||
"""
|
||||
if from_error or self.transport is None or self.transport.is_closing():
|
||||
logger.error(
|
||||
"Transport closed @ %s and exception "
|
||||
"experienced during error handling",
|
||||
(
|
||||
self.transport.get_extra_info("peername")
|
||||
if self.transport is not None
|
||||
else "N/A"
|
||||
),
|
||||
)
|
||||
logger.debug("Exception:", exc_info=True)
|
||||
else:
|
||||
self.write_error(ServerError(message))
|
||||
logger.error(message)
|
||||
|
||||
def cleanup(self):
|
||||
"""This is called when KeepAlive feature is used,
|
||||
it resets the connection in order for it to be able
|
||||
to handle receiving another request on the same connection."""
|
||||
self.parser = None
|
||||
self.request = None
|
||||
self.url = None
|
||||
self.headers = None
|
||||
self._request_handler_task = None
|
||||
self._request_stream_task = None
|
||||
self._total_request_size = 0
|
||||
self._is_stream_handler = False
|
||||
self._time = current_time()
|
||||
|
||||
def close_if_idle(self):
|
||||
"""Close the connection if a request is not being sent or received
|
||||
|
||||
:return: boolean - True if closed, false if staying open
|
||||
"""
|
||||
if not self.parser and self.transport is not None:
|
||||
self.transport.close()
|
||||
if self._http is None or self._http.stage is Stage.IDLE:
|
||||
self.close()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -679,10 +244,57 @@ class HttpProtocol(asyncio.Protocol):
|
|||
"""
|
||||
Force close the connection.
|
||||
"""
|
||||
if self.transport is not None:
|
||||
# Cause a call to connection_lost where further cleanup occurs
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Only asyncio.Protocol callbacks below this
|
||||
# -------------------------------------------- #
|
||||
|
||||
def connection_made(self, transport):
|
||||
try:
|
||||
# TODO: Benchmark to find suitable write buffer limits
|
||||
transport.set_write_buffer_limits(low=16384, high=65536)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self._task = self.loop.create_task(self.connection_task())
|
||||
self.recv_buffer = bytearray()
|
||||
self.conn_info = ConnInfo(self.transport, unix=self._unix)
|
||||
except Exception:
|
||||
logger.exception("protocol.connect_made")
|
||||
|
||||
def connection_lost(self, exc):
|
||||
try:
|
||||
self.connections.discard(self)
|
||||
self.resume_writing()
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
except Exception:
|
||||
logger.exception("protocol.connection_lost")
|
||||
|
||||
def pause_writing(self):
|
||||
self._can_write.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self._can_write.set()
|
||||
|
||||
def data_received(self, data):
|
||||
try:
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
self.recv_buffer += data
|
||||
|
||||
if len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE:
|
||||
self.transport.pause_reading()
|
||||
|
||||
if self._data_received:
|
||||
self._data_received.set()
|
||||
except Exception:
|
||||
logger.exception("protocol.data_received")
|
||||
|
||||
|
||||
def trigger_events(events, loop):
|
||||
"""Trigger event callbacks (functions or async)
|
||||
|
|
|
@ -47,11 +47,21 @@ class SanicTestClient:
|
|||
async with self.get_new_session() as session:
|
||||
|
||||
try:
|
||||
if method == "request":
|
||||
args = [url] + list(args)
|
||||
url = kwargs.pop("http_method", "GET").upper()
|
||||
response = await getattr(session, method.lower())(
|
||||
url, *args, **kwargs
|
||||
)
|
||||
except NameError:
|
||||
raise Exception(response.status_code)
|
||||
except httpx.HTTPError as e:
|
||||
if hasattr(e, "response"):
|
||||
response = e.response
|
||||
else:
|
||||
logger.error(
|
||||
f"{method.upper()} {url} received no response!",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
response.body = await response.aread()
|
||||
response.status = response.status_code
|
||||
|
@ -85,7 +95,6 @@ class SanicTestClient:
|
|||
):
|
||||
results = [None, None]
|
||||
exceptions = []
|
||||
|
||||
if gather_request:
|
||||
|
||||
def _collect_request(request):
|
||||
|
@ -161,6 +170,9 @@ class SanicTestClient:
|
|||
except BaseException: # noqa
|
||||
raise ValueError(f"Request object expected, got ({results})")
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
return self._sanic_endpoint_test("request", *args, **kwargs)
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self._sanic_endpoint_test("get", *args, **kwargs)
|
||||
|
||||
|
|
|
@ -118,7 +118,7 @@ def test_app_route_raise_value_error(app):
|
|||
|
||||
def test_app_handle_request_handler_is_none(app, monkeypatch):
|
||||
def mockreturn(*args, **kwargs):
|
||||
return None, [], {}, "", "", None
|
||||
return None, [], {}, "", "", None, False
|
||||
|
||||
# Not sure how to make app.router.get() return None, so use mock here.
|
||||
monkeypatch.setattr(app.router, "get", mockreturn)
|
||||
|
|
|
@ -8,7 +8,7 @@ def test_bad_request_response(app):
|
|||
async def _request(sanic, loop):
|
||||
connect = asyncio.open_connection("127.0.0.1", 42101)
|
||||
reader, writer = await connect
|
||||
writer.write(b"not http")
|
||||
writer.write(b"not http\r\n\r\n")
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line:
|
||||
|
|
|
@ -71,7 +71,6 @@ def test_bp(app):
|
|||
|
||||
app.blueprint(bp)
|
||||
request, response = app.test_client.get("/")
|
||||
assert app.is_request_stream is False
|
||||
|
||||
assert response.text == "Hello"
|
||||
|
||||
|
@ -505,48 +504,38 @@ def test_bp_shorthand(app):
|
|||
|
||||
@blueprint.get("/get")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.put("/put")
|
||||
def put_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.post("/post")
|
||||
def post_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.head("/head")
|
||||
def head_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.options("/options")
|
||||
def options_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.patch("/patch")
|
||||
def patch_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.delete("/delete")
|
||||
def delete_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@blueprint.websocket("/ws/", strict_slashes=True)
|
||||
async def websocket_handler(request, ws):
|
||||
assert request.stream is None
|
||||
ev.set()
|
||||
|
||||
app.blueprint(blueprint)
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.get("/get")
|
||||
assert response.text == "OK"
|
||||
|
||||
|
|
|
@ -6,17 +6,14 @@ from sanic.response import json_dumps, text
|
|||
|
||||
|
||||
class CustomRequest(Request):
|
||||
__slots__ = ("body_buffer",)
|
||||
"""Alternative implementation for loading body (non-streaming handlers)"""
|
||||
|
||||
def body_init(self):
|
||||
self.body_buffer = BytesIO()
|
||||
|
||||
def body_push(self, data):
|
||||
self.body_buffer.write(data)
|
||||
|
||||
def body_finish(self):
|
||||
self.body = self.body_buffer.getvalue()
|
||||
self.body_buffer.close()
|
||||
async def receive_body(self):
|
||||
buffer = BytesIO()
|
||||
async for data in self.stream:
|
||||
buffer.write(data)
|
||||
self.body = buffer.getvalue().upper()
|
||||
buffer.close()
|
||||
|
||||
|
||||
def test_custom_request():
|
||||
|
@ -37,17 +34,13 @@ def test_custom_request():
|
|||
"/post", data=json_dumps(payload), headers=headers
|
||||
)
|
||||
|
||||
assert isinstance(request.body_buffer, BytesIO)
|
||||
assert request.body_buffer.closed
|
||||
assert request.body == b'{"test":"OK"}'
|
||||
assert request.json.get("test") == "OK"
|
||||
assert request.body == b'{"TEST":"OK"}'
|
||||
assert request.json.get("TEST") == "OK"
|
||||
assert response.text == "OK"
|
||||
assert response.status == 200
|
||||
|
||||
request, response = app.test_client.get("/get")
|
||||
|
||||
assert isinstance(request.body_buffer, BytesIO)
|
||||
assert request.body_buffer.closed
|
||||
assert request.body == b""
|
||||
assert response.text == "OK"
|
||||
assert response.status == 200
|
||||
|
|
|
@ -1,14 +1,26 @@
|
|||
import asyncio
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.exceptions import InvalidUsage, NotFound, ServerError
|
||||
from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError
|
||||
from sanic.handlers import ErrorHandler
|
||||
from sanic.response import text
|
||||
from sanic.response import stream, text
|
||||
|
||||
|
||||
exception_handler_app = Sanic("test_exception_handler")
|
||||
|
||||
|
||||
async def sample_streaming_fn(response):
|
||||
await response.write("foo,")
|
||||
await asyncio.sleep(0.001)
|
||||
await response.write("bar")
|
||||
|
||||
|
||||
class ErrorWithRequestCtx(ServerError):
|
||||
pass
|
||||
|
||||
|
||||
@exception_handler_app.route("/1")
|
||||
def handler_1(request):
|
||||
raise InvalidUsage("OK")
|
||||
|
@ -47,11 +59,40 @@ def handler_6(request, arg):
|
|||
return text(foo)
|
||||
|
||||
|
||||
@exception_handler_app.exception(NotFound, ServerError)
|
||||
@exception_handler_app.route("/7")
|
||||
def handler_7(request):
|
||||
raise Forbidden("go away!")
|
||||
|
||||
|
||||
@exception_handler_app.route("/8")
|
||||
def handler_8(request):
|
||||
|
||||
raise ErrorWithRequestCtx("OK")
|
||||
|
||||
|
||||
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
|
||||
def handler_exception_with_ctx(request, exception):
|
||||
return text(request.ctx.middleware_ran)
|
||||
|
||||
|
||||
@exception_handler_app.exception(ServerError)
|
||||
def handler_exception(request, exception):
|
||||
return text("OK")
|
||||
|
||||
|
||||
@exception_handler_app.exception(Forbidden)
|
||||
async def async_handler_exception(request, exception):
|
||||
return stream(
|
||||
sample_streaming_fn,
|
||||
content_type="text/csv",
|
||||
)
|
||||
|
||||
|
||||
@exception_handler_app.middleware
|
||||
async def some_request_middleware(request):
|
||||
request.ctx.middleware_ran = "Done."
|
||||
|
||||
|
||||
def test_invalid_usage_exception_handler():
|
||||
request, response = exception_handler_app.test_client.get("/1")
|
||||
assert response.status == 400
|
||||
|
@ -71,7 +112,13 @@ def test_not_found_exception_handler():
|
|||
def test_text_exception__handler():
|
||||
request, response = exception_handler_app.test_client.get("/random")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
assert response.text == "Done."
|
||||
|
||||
|
||||
def test_async_exception_handler():
|
||||
request, response = exception_handler_app.test_client.get("/7")
|
||||
assert response.status == 200
|
||||
assert response.text == "foo,bar"
|
||||
|
||||
|
||||
def test_html_traceback_output_in_debug_mode():
|
||||
|
@ -156,3 +203,9 @@ def test_exception_handler_lookup():
|
|||
assert handler.lookup(CustomError()) == custom_error_handler
|
||||
assert handler.lookup(ServerError("Error")) == server_error_handler
|
||||
assert handler.lookup(CustomServerError("Error")) == server_error_handler
|
||||
|
||||
|
||||
def test_exception_handler_processed_request_middleware():
|
||||
request, response = exception_handler_app.test_client.get("/8")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import headers
|
||||
from sanic import Sanic, headers
|
||||
from sanic.compat import Header
|
||||
from sanic.exceptions import PayloadTooLarge
|
||||
from sanic.http import Http
|
||||
from sanic.request import Request
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -61,3 +67,21 @@ from sanic import headers
|
|||
)
|
||||
def test_parse_headers(input, expected):
|
||||
assert headers.parse_content_header(input) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_header_size_exceeded():
|
||||
recv_buffer = bytearray()
|
||||
|
||||
async def _receive_more():
|
||||
nonlocal recv_buffer
|
||||
recv_buffer += b"123"
|
||||
|
||||
protocol = Mock()
|
||||
http = Http(protocol)
|
||||
http._receive_more = _receive_more
|
||||
http.request_max_size = 1
|
||||
http.recv_buffer = recv_buffer
|
||||
|
||||
with pytest.raises(PayloadTooLarge):
|
||||
await http.http1_request_header()
|
||||
|
|
|
@ -2,11 +2,14 @@ import asyncio
|
|||
|
||||
from asyncio import sleep as aio_sleep
|
||||
from json import JSONDecodeError
|
||||
from os import environ
|
||||
|
||||
import httpcore
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic, server
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.response import text
|
||||
from sanic.testing import HOST, SanicTestClient
|
||||
|
||||
|
@ -202,6 +205,10 @@ async def handler3(request):
|
|||
return text("OK")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
|
||||
reason="Not testable with current client",
|
||||
)
|
||||
def test_keep_alive_timeout_reuse():
|
||||
"""If the server keep-alive timeout and client keep-alive timeout are
|
||||
both longer than the delay, the client _and_ server will successfully
|
||||
|
@ -223,6 +230,10 @@ def test_keep_alive_timeout_reuse():
|
|||
client.kill_server()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
|
||||
reason="Not testable with current client",
|
||||
)
|
||||
def test_keep_alive_client_timeout():
|
||||
"""If the server keep-alive timeout is longer than the client
|
||||
keep-alive timeout, client will try to create a new connection here."""
|
||||
|
@ -244,6 +255,10 @@ def test_keep_alive_client_timeout():
|
|||
client.kill_server()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
|
||||
reason="Not testable with current client",
|
||||
)
|
||||
def test_keep_alive_server_timeout():
|
||||
"""If the client keep-alive timeout is longer than the server
|
||||
keep-alive timeout, the client will either a 'Connection reset' error
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from importlib import reload
|
||||
from io import StringIO
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import sanic
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
|
||||
from sanic.response import text
|
||||
from sanic.testing import SanicTestClient
|
||||
|
@ -111,12 +112,11 @@ def test_log_connection_lost(app, debug, monkeypatch):
|
|||
@app.route("/conn_lost")
|
||||
async def conn_lost(request):
|
||||
response = text("Ok")
|
||||
response.output = Mock(side_effect=RuntimeError)
|
||||
request.transport.close()
|
||||
return response
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# catch ValueError: Exception during request
|
||||
app.test_client.get("/conn_lost", debug=debug)
|
||||
req, res = app.test_client.get("/conn_lost", debug=debug)
|
||||
assert res is None
|
||||
|
||||
log = stream.getvalue()
|
||||
|
||||
|
@ -126,7 +126,8 @@ def test_log_connection_lost(app, debug, monkeypatch):
|
|||
assert "Connection lost before response written @" not in log
|
||||
|
||||
|
||||
def test_logger(caplog):
|
||||
@pytest.mark.asyncio
|
||||
async def test_logger(caplog):
|
||||
rand_string = str(uuid.uuid4())
|
||||
|
||||
app = Sanic(name=__name__)
|
||||
|
@ -137,32 +138,16 @@ def test_logger(caplog):
|
|||
return text("hello")
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
request, response = app.test_client.get("/")
|
||||
_ = await app.asgi_client.get("/")
|
||||
|
||||
port = request.server_port
|
||||
|
||||
# Note: testing with random port doesn't show the banner because it doesn't
|
||||
# define host and port. This test supports both modes.
|
||||
if caplog.record_tuples[0] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
f"Goin' Fast @ http://127.0.0.1:{port}",
|
||||
):
|
||||
caplog.record_tuples.pop(0)
|
||||
|
||||
assert caplog.record_tuples[0] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
f"http://127.0.0.1:{port}/",
|
||||
)
|
||||
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
|
||||
assert caplog.record_tuples[-1] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"Server Stopped",
|
||||
)
|
||||
record = ("sanic.root", logging.INFO, rand_string)
|
||||
assert record in caplog.record_tuples
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
OS_IS_WINDOWS and sys.version_info >= (3, 8),
|
||||
reason="Not testable with current client",
|
||||
)
|
||||
def test_logger_static_and_secure(caplog):
|
||||
# Same as test_logger, except for more coverage:
|
||||
# - test_client initialised separately for static port
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from asyncio import CancelledError
|
||||
from itertools import count
|
||||
|
||||
from sanic.exceptions import NotFound, SanicException
|
||||
from sanic.request import Request
|
||||
|
@ -54,7 +55,7 @@ def test_middleware_response(app):
|
|||
|
||||
|
||||
def test_middleware_response_exception(app):
|
||||
result = {"status_code": None}
|
||||
result = {"status_code": "middleware not run"}
|
||||
|
||||
@app.middleware("response")
|
||||
async def process_response(request, response):
|
||||
|
@ -184,3 +185,23 @@ def test_middleware_order(app):
|
|||
|
||||
assert response.status == 200
|
||||
assert order == [1, 2, 3, 4, 5, 6]
|
||||
|
||||
|
||||
def test_request_middleware_executes_once(app):
|
||||
i = count()
|
||||
|
||||
@app.middleware("request")
|
||||
async def inc(request):
|
||||
nonlocal i
|
||||
next(i)
|
||||
|
||||
@app.route("/")
|
||||
async def handler(request):
|
||||
await request.app._run_request_middleware(request)
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
assert next(i) == 1
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
assert next(i) == 3
|
||||
|
|
|
@ -85,7 +85,6 @@ def test_pickle_app_with_bp(app, protocol):
|
|||
up_p_app = pickle.loads(p_app)
|
||||
assert up_p_app
|
||||
request, response = up_p_app.test_client.get("/")
|
||||
assert up_p_app.is_request_stream is False
|
||||
assert response.text == "Hello"
|
||||
|
||||
|
||||
|
|
|
@ -107,10 +107,8 @@ def test_shorthand_named_routes_post(app):
|
|||
def test_shorthand_named_routes_put(app):
|
||||
@app.put("/put", name="route_put")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
assert app.router.routes_all["/put"].name == "route_put"
|
||||
assert app.url_for("route_put") == "/put"
|
||||
with pytest.raises(URLBuildError):
|
||||
|
@ -120,10 +118,8 @@ def test_shorthand_named_routes_put(app):
|
|||
def test_shorthand_named_routes_delete(app):
|
||||
@app.delete("/delete", name="route_delete")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
assert app.router.routes_all["/delete"].name == "route_delete"
|
||||
assert app.url_for("route_delete") == "/delete"
|
||||
with pytest.raises(URLBuildError):
|
||||
|
@ -133,10 +129,8 @@ def test_shorthand_named_routes_delete(app):
|
|||
def test_shorthand_named_routes_patch(app):
|
||||
@app.patch("/patch", name="route_patch")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
assert app.router.routes_all["/patch"].name == "route_patch"
|
||||
assert app.url_for("route_patch") == "/patch"
|
||||
with pytest.raises(URLBuildError):
|
||||
|
@ -146,10 +140,8 @@ def test_shorthand_named_routes_patch(app):
|
|||
def test_shorthand_named_routes_head(app):
|
||||
@app.head("/head", name="route_head")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
assert app.router.routes_all["/head"].name == "route_head"
|
||||
assert app.url_for("route_head") == "/head"
|
||||
with pytest.raises(URLBuildError):
|
||||
|
@ -159,10 +151,8 @@ def test_shorthand_named_routes_head(app):
|
|||
def test_shorthand_named_routes_options(app):
|
||||
@app.options("/options", name="route_options")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
assert app.router.routes_all["/options"].name == "route_options"
|
||||
assert app.url_for("route_options") == "/options"
|
||||
with pytest.raises(URLBuildError):
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_payload_too_large_at_data_received_default(app):
|
|||
|
||||
response = app.test_client.get("/1", gather_request=False)
|
||||
assert response.status == 413
|
||||
assert "Payload Too Large" in response.text
|
||||
assert "Request header" in response.text
|
||||
|
||||
|
||||
def test_payload_too_large_at_on_header_default(app):
|
||||
|
@ -40,4 +40,4 @@ def test_payload_too_large_at_on_header_default(app):
|
|||
data = "a" * 1000
|
||||
response = app.test_client.post("/1", gather_request=False, data=data)
|
||||
assert response.status == 413
|
||||
assert "Payload Too Large" in response.text
|
||||
assert "Request body" in response.text
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
import io
|
||||
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
data = "abc" * 10_000_000
|
||||
|
||||
|
||||
def test_request_buffer_queue_size(app):
|
||||
default_buf_qsz = app.config.get("REQUEST_BUFFER_QUEUE_SIZE")
|
||||
qsz = 1
|
||||
while qsz == default_buf_qsz:
|
||||
qsz += 1
|
||||
app.config.REQUEST_BUFFER_QUEUE_SIZE = qsz
|
||||
|
||||
@app.post("/post", stream=True)
|
||||
async def post(request):
|
||||
assert request.stream.buffer_size == qsz
|
||||
print("request.stream.buffer_size =", request.stream.buffer_size)
|
||||
|
||||
bio = io.BytesIO()
|
||||
while True:
|
||||
bdata = await request.stream.read()
|
||||
if not bdata:
|
||||
break
|
||||
bio.write(bdata)
|
||||
|
||||
head = bdata[:3].decode("utf-8")
|
||||
tail = bdata[3:][-3:].decode("utf-8")
|
||||
print(head, "...", tail)
|
||||
|
||||
bio.seek(0)
|
||||
return text(bio.read().decode("utf-8"))
|
||||
|
||||
request, response = app.test_client.post("/post", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
|
@ -1,11 +1,13 @@
|
|||
import asyncio
|
||||
|
||||
from contextlib import closing
|
||||
from socket import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.exceptions import HeaderExpectationFailed
|
||||
from sanic.request import StreamBuffer
|
||||
from sanic.response import json, stream, text
|
||||
from sanic.response import json, text
|
||||
from sanic.server import HttpProtocol
|
||||
from sanic.views import CompositionView, HTTPMethodView
|
||||
from sanic.views import stream as stream_decorator
|
||||
|
@ -15,28 +17,22 @@ data = "abc" * 1_000_000
|
|||
|
||||
|
||||
def test_request_stream_method_view(app):
|
||||
"""for self.is_request_stream = True"""
|
||||
|
||||
class SimpleView(HTTPMethodView):
|
||||
def get(self, request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@stream_decorator
|
||||
async def post(self, request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
result = b""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
result += body
|
||||
return text(result.decode())
|
||||
|
||||
app.add_route(SimpleView.as_view(), "/method_view")
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.get("/method_view")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
@ -50,14 +46,16 @@ def test_request_stream_method_view(app):
|
|||
"headers, expect_raise_exception",
|
||||
[
|
||||
({"EXPECT": "100-continue"}, False),
|
||||
({"EXPECT": "100-continue-extra"}, True),
|
||||
# The below test SHOULD work, and it does produce a 417
|
||||
# However, httpx now intercepts this and raises an exception,
|
||||
# so we will need a new method for testing this
|
||||
# ({"EXPECT": "100-continue-extra"}, True),
|
||||
],
|
||||
)
|
||||
def test_request_stream_100_continue(app, headers, expect_raise_exception):
|
||||
class SimpleView(HTTPMethodView):
|
||||
@stream_decorator
|
||||
async def post(self, request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -68,55 +66,42 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception):
|
|||
|
||||
app.add_route(SimpleView.as_view(), "/method_view")
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
if not expect_raise_exception:
|
||||
request, response = app.test_client.post(
|
||||
"/method_view", data=data, headers={"EXPECT": "100-continue"}
|
||||
"/method_view", data=data, headers=headers
|
||||
)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
||||
else:
|
||||
with pytest.raises(ValueError) as e:
|
||||
app.test_client.post(
|
||||
"/method_view",
|
||||
data=data,
|
||||
headers={"EXPECT": "100-continue-extra"},
|
||||
)
|
||||
assert "Unknown Expect: 100-continue-extra" in str(e)
|
||||
request, response = app.test_client.post(
|
||||
"/method_view", data=data, headers=headers
|
||||
)
|
||||
assert response.status == 417
|
||||
|
||||
|
||||
def test_request_stream_app(app):
|
||||
"""for self.is_request_stream = True and decorators"""
|
||||
|
||||
@app.get("/get")
|
||||
async def get(request):
|
||||
assert request.stream is None
|
||||
return text("GET")
|
||||
|
||||
@app.head("/head")
|
||||
async def head(request):
|
||||
assert request.stream is None
|
||||
return text("HEAD")
|
||||
|
||||
@app.delete("/delete")
|
||||
async def delete(request):
|
||||
assert request.stream is None
|
||||
return text("DELETE")
|
||||
|
||||
@app.options("/options")
|
||||
async def options(request):
|
||||
assert request.stream is None
|
||||
return text("OPTIONS")
|
||||
|
||||
@app.post("/_post/<id>")
|
||||
async def _post(request, id):
|
||||
assert request.stream is None
|
||||
return text("_POST")
|
||||
|
||||
@app.post("/post/<id>", stream=True)
|
||||
async def post(request, id):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -127,12 +112,10 @@ def test_request_stream_app(app):
|
|||
|
||||
@app.put("/_put")
|
||||
async def _put(request):
|
||||
assert request.stream is None
|
||||
return text("_PUT")
|
||||
|
||||
@app.put("/put", stream=True)
|
||||
async def put(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -143,12 +126,10 @@ def test_request_stream_app(app):
|
|||
|
||||
@app.patch("/_patch")
|
||||
async def _patch(request):
|
||||
assert request.stream is None
|
||||
return text("_PATCH")
|
||||
|
||||
@app.patch("/patch", stream=True)
|
||||
async def patch(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -157,8 +138,6 @@ def test_request_stream_app(app):
|
|||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.get("/get")
|
||||
assert response.status == 200
|
||||
assert response.text == "GET"
|
||||
|
@ -202,36 +181,28 @@ def test_request_stream_app(app):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_stream_app_asgi(app):
|
||||
"""for self.is_request_stream = True and decorators"""
|
||||
|
||||
@app.get("/get")
|
||||
async def get(request):
|
||||
assert request.stream is None
|
||||
return text("GET")
|
||||
|
||||
@app.head("/head")
|
||||
async def head(request):
|
||||
assert request.stream is None
|
||||
return text("HEAD")
|
||||
|
||||
@app.delete("/delete")
|
||||
async def delete(request):
|
||||
assert request.stream is None
|
||||
return text("DELETE")
|
||||
|
||||
@app.options("/options")
|
||||
async def options(request):
|
||||
assert request.stream is None
|
||||
return text("OPTIONS")
|
||||
|
||||
@app.post("/_post/<id>")
|
||||
async def _post(request, id):
|
||||
assert request.stream is None
|
||||
return text("_POST")
|
||||
|
||||
@app.post("/post/<id>", stream=True)
|
||||
async def post(request, id):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -242,12 +213,10 @@ async def test_request_stream_app_asgi(app):
|
|||
|
||||
@app.put("/_put")
|
||||
async def _put(request):
|
||||
assert request.stream is None
|
||||
return text("_PUT")
|
||||
|
||||
@app.put("/put", stream=True)
|
||||
async def put(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -258,12 +227,10 @@ async def test_request_stream_app_asgi(app):
|
|||
|
||||
@app.patch("/_patch")
|
||||
async def _patch(request):
|
||||
assert request.stream is None
|
||||
return text("_PATCH")
|
||||
|
||||
@app.patch("/patch", stream=True)
|
||||
async def patch(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -272,8 +239,6 @@ async def test_request_stream_app_asgi(app):
|
|||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = await app.asgi_client.get("/get")
|
||||
assert response.status == 200
|
||||
assert response.text == "GET"
|
||||
|
@ -320,14 +285,13 @@ def test_request_stream_handle_exception(app):
|
|||
|
||||
@app.post("/post/<id>", stream=True)
|
||||
async def post(request, id):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
result = b""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
result += body
|
||||
return text(result.decode())
|
||||
|
||||
# 404
|
||||
request, response = app.test_client.post("/in_valid_post", data=data)
|
||||
|
@ -340,54 +304,31 @@ def test_request_stream_handle_exception(app):
|
|||
assert "Method GET not allowed for URL /post/random_id" in response.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_stream_unread(app):
|
||||
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
|
||||
|
||||
err = None
|
||||
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
|
||||
try:
|
||||
protocol.request = None
|
||||
protocol._body_chunks.append("this is a test")
|
||||
await protocol.stream_append()
|
||||
except AttributeError as e:
|
||||
err = e
|
||||
|
||||
assert err is None and not protocol._body_chunks
|
||||
|
||||
|
||||
def test_request_stream_blueprint(app):
|
||||
"""for self.is_request_stream = True"""
|
||||
bp = Blueprint("test_blueprint_request_stream_blueprint")
|
||||
|
||||
@app.get("/get")
|
||||
async def get(request):
|
||||
assert request.stream is None
|
||||
return text("GET")
|
||||
|
||||
@bp.head("/head")
|
||||
async def head(request):
|
||||
assert request.stream is None
|
||||
return text("HEAD")
|
||||
|
||||
@bp.delete("/delete")
|
||||
async def delete(request):
|
||||
assert request.stream is None
|
||||
return text("DELETE")
|
||||
|
||||
@bp.options("/options")
|
||||
async def options(request):
|
||||
assert request.stream is None
|
||||
return text("OPTIONS")
|
||||
|
||||
@bp.post("/_post/<id>")
|
||||
async def _post(request, id):
|
||||
assert request.stream is None
|
||||
return text("_POST")
|
||||
|
||||
@bp.post("/post/<id>", stream=True)
|
||||
async def post(request, id):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -398,12 +339,10 @@ def test_request_stream_blueprint(app):
|
|||
|
||||
@bp.put("/_put")
|
||||
async def _put(request):
|
||||
assert request.stream is None
|
||||
return text("_PUT")
|
||||
|
||||
@bp.put("/put", stream=True)
|
||||
async def put(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -414,12 +353,10 @@ def test_request_stream_blueprint(app):
|
|||
|
||||
@bp.patch("/_patch")
|
||||
async def _patch(request):
|
||||
assert request.stream is None
|
||||
return text("_PATCH")
|
||||
|
||||
@bp.patch("/patch", stream=True)
|
||||
async def patch(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -429,7 +366,6 @@ def test_request_stream_blueprint(app):
|
|||
return text(result)
|
||||
|
||||
async def post_add_route(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -443,8 +379,6 @@ def test_request_stream_blueprint(app):
|
|||
)
|
||||
app.blueprint(bp)
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.get("/get")
|
||||
assert response.status == 200
|
||||
assert response.text == "GET"
|
||||
|
@ -491,14 +425,10 @@ def test_request_stream_blueprint(app):
|
|||
|
||||
|
||||
def test_request_stream_composition_view(app):
|
||||
"""for self.is_request_stream = True"""
|
||||
|
||||
def get_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
async def post_handler(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -512,8 +442,6 @@ def test_request_stream_composition_view(app):
|
|||
view.add(["POST"], post_handler, stream=True)
|
||||
app.add_route(view, "/composition_view")
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.get("/composition_view")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
@ -529,12 +457,10 @@ def test_request_stream(app):
|
|||
|
||||
class SimpleView(HTTPMethodView):
|
||||
def get(self, request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@stream_decorator
|
||||
async def post(self, request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -545,7 +471,6 @@ def test_request_stream(app):
|
|||
|
||||
@app.post("/stream", stream=True)
|
||||
async def handler(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -556,12 +481,10 @@ def test_request_stream(app):
|
|||
|
||||
@app.get("/get")
|
||||
async def get(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@bp.post("/bp_stream", stream=True)
|
||||
async def bp_stream(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -572,15 +495,12 @@ def test_request_stream(app):
|
|||
|
||||
@bp.get("/bp_get")
|
||||
async def bp_get(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
def get_handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
async def post_handler(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
|
@ -599,8 +519,6 @@ def test_request_stream(app):
|
|||
|
||||
app.add_route(view, "/composition_view")
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.get("/method_view")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
@ -636,14 +554,14 @@ def test_request_stream(app):
|
|||
|
||||
def test_streaming_new_api(app):
|
||||
@app.post("/non-stream")
|
||||
async def handler(request):
|
||||
async def handler1(request):
|
||||
assert request.body == b"x"
|
||||
await request.receive_body() # This should do nothing
|
||||
assert request.body == b"x"
|
||||
return text("OK")
|
||||
|
||||
@app.post("/1", stream=True)
|
||||
async def handler(request):
|
||||
async def handler2(request):
|
||||
assert request.stream
|
||||
assert not request.body
|
||||
await request.receive_body()
|
||||
|
@ -671,5 +589,85 @@ def test_streaming_new_api(app):
|
|||
assert response.status == 200
|
||||
res = response.json
|
||||
assert isinstance(res, list)
|
||||
assert len(res) > 1
|
||||
assert "".join(res) == data
|
||||
|
||||
|
||||
def test_streaming_echo():
|
||||
"""2-way streaming chat between server and client."""
|
||||
app = Sanic(name=__name__)
|
||||
|
||||
@app.post("/echo", stream=True)
|
||||
async def handler(request):
|
||||
res = await request.respond(content_type="text/plain; charset=utf-8")
|
||||
# Send headers
|
||||
await res.send(end_stream=False)
|
||||
# Echo back data (case swapped)
|
||||
async for data in request.stream:
|
||||
await res.send(data.swapcase())
|
||||
# Add EOF marker after successful operation
|
||||
await res.send(b"-", end_stream=True)
|
||||
|
||||
@app.listener("after_server_start")
|
||||
async def client_task(app, loop):
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(*addr)
|
||||
await client(app, reader, writer)
|
||||
finally:
|
||||
writer.close()
|
||||
app.stop()
|
||||
|
||||
async def client(app, reader, writer):
|
||||
# Unfortunately httpx does not support 2-way streaming, so do it by hand.
|
||||
host = f"host: {addr[0]}:{addr[1]}\r\n".encode()
|
||||
writer.write(
|
||||
b"POST /echo HTTP/1.1\r\n" + host + b"content-length: 2\r\n"
|
||||
b"content-type: text/plain; charset=utf-8\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
# Read response
|
||||
res = b""
|
||||
while not b"\r\n\r\n" in res:
|
||||
res += await reader.read(4096)
|
||||
assert res.startswith(b"HTTP/1.1 200 OK\r\n")
|
||||
assert res.endswith(b"\r\n\r\n")
|
||||
buffer = b""
|
||||
|
||||
async def read_chunk():
|
||||
nonlocal buffer
|
||||
while not b"\r\n" in buffer:
|
||||
data = await reader.read(4096)
|
||||
assert data
|
||||
buffer += data
|
||||
size, buffer = buffer.split(b"\r\n", 1)
|
||||
size = int(size, 16)
|
||||
if size == 0:
|
||||
return None
|
||||
while len(buffer) < size + 2:
|
||||
data = await reader.read(4096)
|
||||
assert data
|
||||
buffer += data
|
||||
print(res)
|
||||
assert buffer[size : size + 2] == b"\r\n"
|
||||
ret, buffer = buffer[:size], buffer[size + 2 :]
|
||||
return ret
|
||||
|
||||
# Chat with server
|
||||
writer.write(b"a")
|
||||
res = await read_chunk()
|
||||
assert res == b"A"
|
||||
|
||||
writer.write(b"b")
|
||||
res = await read_chunk()
|
||||
assert res == b"B"
|
||||
|
||||
res = await read_chunk()
|
||||
assert res == b"-"
|
||||
|
||||
res = await read_chunk()
|
||||
assert res == None
|
||||
|
||||
# Use random port for tests
|
||||
with closing(socket()) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
addr = sock.getsockname()
|
||||
app.run(sock=sock, access_log=False)
|
||||
|
|
|
@ -59,7 +59,7 @@ def test_ip(app):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ip_asgi(app):
|
||||
async def test_url_asgi(app):
|
||||
@app.route("/")
|
||||
def handler(request):
|
||||
return text(f"{request.url}")
|
||||
|
@ -2119,3 +2119,37 @@ def test_url_for_without_server_name(app):
|
|||
response.json["url"]
|
||||
== f"http://127.0.0.1:{request.server_port}/url-for"
|
||||
)
|
||||
|
||||
|
||||
def test_safe_method_with_body_ignored(app):
|
||||
@app.get("/")
|
||||
async def handler(request):
|
||||
return text("OK")
|
||||
|
||||
payload = {"test": "OK"}
|
||||
headers = {"content-type": "application/json"}
|
||||
|
||||
request, response = app.test_client.request(
|
||||
"/", http_method="get", data=json_dumps(payload), headers=headers
|
||||
)
|
||||
|
||||
assert request.body == b""
|
||||
assert request.json == None
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
def test_safe_method_with_body(app):
|
||||
@app.get("/", ignore_body=False)
|
||||
async def handler(request):
|
||||
return text("OK")
|
||||
|
||||
payload = {"test": "OK"}
|
||||
headers = {"content-type": "application/json"}
|
||||
data = json_dumps(payload)
|
||||
request, response = app.test_client.request(
|
||||
"/", http_method="get", data=data, headers=headers
|
||||
)
|
||||
|
||||
assert request.body == data.encode("utf-8")
|
||||
assert request.json.get("test") == "OK"
|
||||
assert response.text == "OK"
|
||||
|
|
|
@ -85,7 +85,6 @@ def test_response_header(app):
|
|||
request, response = app.test_client.get("/")
|
||||
assert dict(response.headers) == {
|
||||
"connection": "keep-alive",
|
||||
"keep-alive": str(app.config.KEEP_ALIVE_TIMEOUT),
|
||||
"content-length": "11",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
@ -201,7 +200,6 @@ def streaming_app(app):
|
|||
async def test(request):
|
||||
return stream(
|
||||
sample_streaming_fn,
|
||||
headers={"Content-Length": "7"},
|
||||
content_type="text/csv",
|
||||
)
|
||||
|
||||
|
@ -243,7 +241,12 @@ async def test_chunked_streaming_returns_correct_content_asgi(streaming_app):
|
|||
|
||||
|
||||
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
|
||||
request, response = non_chunked_streaming_app.test_client.get("/")
|
||||
with pytest.warns(UserWarning) as record:
|
||||
request, response = non_chunked_streaming_app.test_client.get("/")
|
||||
|
||||
assert len(record) == 1
|
||||
assert "removed in v21.6" in record[0].message.args[0]
|
||||
|
||||
assert "Transfer-Encoding" not in response.headers
|
||||
assert response.headers["Content-Type"] == "text/csv"
|
||||
assert response.headers["Content-Length"] == "7"
|
||||
|
@ -266,115 +269,6 @@ def test_non_chunked_streaming_returns_correct_content(
|
|||
assert response.text == "foo,bar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status", [200, 201, 400, 401])
|
||||
def test_stream_response_status_returns_correct_headers(status):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn, status=status)
|
||||
headers = response.get_headers()
|
||||
assert b"HTTP/1.1 %s" % str(status).encode() in headers
|
||||
|
||||
|
||||
@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
|
||||
def test_stream_response_keep_alive_returns_correct_headers(
|
||||
keep_alive_timeout,
|
||||
):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
headers = response.get_headers(
|
||||
keep_alive=True, keep_alive_timeout=keep_alive_timeout
|
||||
)
|
||||
|
||||
assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers
|
||||
|
||||
|
||||
def test_stream_response_includes_chunked_header_http11():
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
headers = response.get_headers(version="1.1")
|
||||
assert b"Transfer-Encoding: chunked\r\n" in headers
|
||||
|
||||
|
||||
def test_stream_response_does_not_include_chunked_header_http10():
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
headers = response.get_headers(version="1.0")
|
||||
assert b"Transfer-Encoding: chunked\r\n" not in headers
|
||||
|
||||
|
||||
def test_stream_response_does_not_include_chunked_header_if_disabled():
|
||||
response = StreamingHTTPResponse(sample_streaming_fn, chunked=False)
|
||||
headers = response.get_headers(version="1.1")
|
||||
assert b"Transfer-Encoding: chunked\r\n" not in headers
|
||||
|
||||
|
||||
def test_stream_response_writes_correct_content_to_transport_when_chunked(
|
||||
streaming_app,
|
||||
):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
response.protocol = MagicMock(HttpProtocol)
|
||||
response.protocol.transport = MagicMock(asyncio.Transport)
|
||||
|
||||
async def mock_drain():
|
||||
pass
|
||||
|
||||
async def mock_push_data(data):
|
||||
response.protocol.transport.write(data)
|
||||
|
||||
response.protocol.push_data = mock_push_data
|
||||
response.protocol.drain = mock_drain
|
||||
|
||||
@streaming_app.listener("after_server_start")
|
||||
async def run_stream(app, loop):
|
||||
await response.stream()
|
||||
assert response.protocol.transport.write.call_args_list[1][0][0] == (
|
||||
b"4\r\nfoo,\r\n"
|
||||
)
|
||||
|
||||
assert response.protocol.transport.write.call_args_list[2][0][0] == (
|
||||
b"3\r\nbar\r\n"
|
||||
)
|
||||
|
||||
assert response.protocol.transport.write.call_args_list[3][0][0] == (
|
||||
b"0\r\n\r\n"
|
||||
)
|
||||
|
||||
assert len(response.protocol.transport.write.call_args_list) == 4
|
||||
|
||||
app.stop()
|
||||
|
||||
streaming_app.run(host=HOST, port=PORT)
|
||||
|
||||
|
||||
def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
|
||||
streaming_app,
|
||||
):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
response.protocol = MagicMock(HttpProtocol)
|
||||
response.protocol.transport = MagicMock(asyncio.Transport)
|
||||
|
||||
async def mock_drain():
|
||||
pass
|
||||
|
||||
async def mock_push_data(data):
|
||||
response.protocol.transport.write(data)
|
||||
|
||||
response.protocol.push_data = mock_push_data
|
||||
response.protocol.drain = mock_drain
|
||||
|
||||
@streaming_app.listener("after_server_start")
|
||||
async def run_stream(app, loop):
|
||||
await response.stream(version="1.0")
|
||||
assert response.protocol.transport.write.call_args_list[1][0][0] == (
|
||||
b"foo,"
|
||||
)
|
||||
|
||||
assert response.protocol.transport.write.call_args_list[2][0][0] == (
|
||||
b"bar"
|
||||
)
|
||||
|
||||
assert len(response.protocol.transport.write.call_args_list) == 3
|
||||
|
||||
app.stop()
|
||||
|
||||
streaming_app.run(host=HOST, port=PORT)
|
||||
|
||||
|
||||
def test_stream_response_with_cookies(app):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
|
|
|
@ -67,16 +67,12 @@ def test_shorthand_routes_multiple(app):
|
|||
def test_route_strict_slash(app):
|
||||
@app.get("/get", strict_slashes=True)
|
||||
def handler1(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
@app.post("/post/", strict_slashes=True)
|
||||
def handler2(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.get("/get")
|
||||
assert response.text == "OK"
|
||||
|
||||
|
@ -216,11 +212,8 @@ def test_shorthand_routes_post(app):
|
|||
def test_shorthand_routes_put(app):
|
||||
@app.put("/put")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.put("/put")
|
||||
assert response.text == "OK"
|
||||
|
||||
|
@ -231,11 +224,8 @@ def test_shorthand_routes_put(app):
|
|||
def test_shorthand_routes_delete(app):
|
||||
@app.delete("/delete")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.delete("/delete")
|
||||
assert response.text == "OK"
|
||||
|
||||
|
@ -246,11 +236,8 @@ def test_shorthand_routes_delete(app):
|
|||
def test_shorthand_routes_patch(app):
|
||||
@app.patch("/patch")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.patch("/patch")
|
||||
assert response.text == "OK"
|
||||
|
||||
|
@ -261,11 +248,8 @@ def test_shorthand_routes_patch(app):
|
|||
def test_shorthand_routes_head(app):
|
||||
@app.head("/head")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.head("/head")
|
||||
assert response.status == 200
|
||||
|
||||
|
@ -276,11 +260,8 @@ def test_shorthand_routes_head(app):
|
|||
def test_shorthand_routes_options(app):
|
||||
@app.options("/options")
|
||||
def handler(request):
|
||||
assert request.stream is None
|
||||
return text("OK")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = app.test_client.options("/options")
|
||||
assert response.status == 200
|
||||
|
||||
|
|
74
tests/test_timeout_logic.py
Normal file
74
tests/test_timeout_logic.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import asyncio
|
||||
|
||||
from time import monotonic as current_time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
||||
from sanic.http import Stage
|
||||
from sanic.server import HttpProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
return Sanic("test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_transport():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol(app, mock_transport):
|
||||
loop = asyncio.new_event_loop()
|
||||
protocol = HttpProtocol(loop=loop, app=app)
|
||||
protocol.connection_made(mock_transport)
|
||||
protocol._setup_connection()
|
||||
protocol._task = Mock(spec=asyncio.Task)
|
||||
protocol._task.cancel = Mock()
|
||||
return protocol
|
||||
|
||||
|
||||
def test_setup(protocol: HttpProtocol):
|
||||
assert protocol._task is not None
|
||||
assert protocol._http is not None
|
||||
assert protocol._time is not None
|
||||
|
||||
|
||||
def test_check_timeouts_no_timeout(protocol: HttpProtocol):
|
||||
protocol.keep_alive_timeout = 1
|
||||
protocol.loop.call_later = Mock()
|
||||
protocol.check_timeouts()
|
||||
protocol._task.cancel.assert_not_called()
|
||||
assert protocol._http.stage is Stage.IDLE
|
||||
assert protocol._http.exception is None
|
||||
protocol.loop.call_later.assert_called_with(
|
||||
protocol.keep_alive_timeout / 2, protocol.check_timeouts
|
||||
)
|
||||
|
||||
|
||||
def test_check_timeouts_keep_alive_timeout(protocol: HttpProtocol):
|
||||
protocol._http.stage = Stage.IDLE
|
||||
protocol._time = 0
|
||||
protocol.check_timeouts()
|
||||
protocol._task.cancel.assert_called_once()
|
||||
assert protocol._http.exception is None
|
||||
|
||||
|
||||
def test_check_timeouts_request_timeout(protocol: HttpProtocol):
|
||||
protocol._http.stage = Stage.REQUEST
|
||||
protocol._time = 0
|
||||
protocol.check_timeouts()
|
||||
protocol._task.cancel.assert_called_once()
|
||||
assert isinstance(protocol._http.exception, RequestTimeout)
|
||||
|
||||
|
||||
def test_check_timeouts_response_timeout(protocol: HttpProtocol):
|
||||
protocol._http.stage = Stage.RESPONSE
|
||||
protocol._time = 0
|
||||
protocol.check_timeouts()
|
||||
protocol._task.cancel.assert_called_once()
|
||||
assert isinstance(protocol._http.exception, ServiceUnavailable)
|
|
@ -12,7 +12,6 @@ from sanic.views import CompositionView, HTTPMethodView
|
|||
def test_methods(app, method):
|
||||
class DummyView(HTTPMethodView):
|
||||
async def get(self, request):
|
||||
assert request.stream is None
|
||||
return text("", headers={"method": "GET"})
|
||||
|
||||
def post(self, request):
|
||||
|
@ -34,7 +33,6 @@ def test_methods(app, method):
|
|||
return text("", headers={"method": "DELETE"})
|
||||
|
||||
app.add_route(DummyView.as_view(), "/")
|
||||
assert app.is_request_stream is False
|
||||
|
||||
request, response = getattr(app.test_client, method.lower())("/")
|
||||
assert response.headers["method"] == method
|
||||
|
@ -69,7 +67,6 @@ def test_with_bp(app):
|
|||
|
||||
class DummyView(HTTPMethodView):
|
||||
def get(self, request):
|
||||
assert request.stream is None
|
||||
return text("I am get method")
|
||||
|
||||
bp.add_route(DummyView.as_view(), "/")
|
||||
|
@ -77,7 +74,6 @@ def test_with_bp(app):
|
|||
app.blueprint(bp)
|
||||
request, response = app.test_client.get("/")
|
||||
|
||||
assert app.is_request_stream is False
|
||||
assert response.text == "I am get method"
|
||||
|
||||
|
||||
|
@ -210,14 +206,12 @@ def test_composition_view_runs_methods_as_expected(app, method):
|
|||
view = CompositionView()
|
||||
|
||||
def first(request):
|
||||
assert request.stream is None
|
||||
return text("first method")
|
||||
|
||||
view.add(["GET", "POST", "PUT"], first)
|
||||
view.add(["DELETE", "PATCH"], lambda x: text("second method"))
|
||||
|
||||
app.add_route(view, "/")
|
||||
assert app.is_request_stream is False
|
||||
|
||||
if method in ["GET", "POST", "PUT"]:
|
||||
request, response = getattr(app.test_client, method.lower())("/")
|
||||
|
|
Loading…
Reference in New Issue
Block a user