Streaming Server (#1876)

* Streaming request by async for.

* Make all requests streaming and preload body for non-streaming handlers.

* Cleanup of code and avoid mixing streaming responses.

* Async http protocol loop.

* Change of test: don't require early bad request error but only after CRLF-CRLF.

* Add back streaming requests.

* Rewritten request body parser.

* Misc. cleanup, down to 4 failing tests.

* All tests OK.

* Entirely remove request body queue.

* Let black f*ckup the layout

* Better testing error messages on protocol errors.

* Remove StreamBuffer tests because the type is about to be removed.

* Remove tests using the deprecated get_headers function that can no longer be supported. Chunked mode is now autodetected, so do not put content-length header if chunked mode is preferred.

* Major refactoring of HTTP protocol handling (new module http.py added), all requests made streaming. A few compatibility issues and a lot of cleanup to be done remain, 16 tests failing.

* Terminate check_timeouts once connection_task finishes.

* Code cleanup, 14 tests failing.

* Much cleanup, 12 failing...

* Even more cleanup and error checking, 8 failing tests.

* Remove keep-alive header from responses. First of all, it should say timeout=<value> which wasn't the case with existing implementation, and secondly none of the other web servers I tried include this header.

* Everything but CustomServer OK.

* Linter

* Disable custom protocol test

* Remove unnecessary variables, optimise performance.

* A test was missing that body_init/body_push/body_finish are never called. Rewritten using receive_body and case switching to make it fail if bypassed.

* Minor fixes.

* Remove unused code.

* Py 3.8 check for deprecated loop argument.

* Fix a middleware cancellation handling test with py38.

* Linter 'n fixes

* Typing

* Stricter handling of request header size

* More specific error messages on Payload Too Large.

* Init http.response = None

* Messages further tuned.

* Always try to consume request body, plus minor cleanup.

* Add a missing check in case of close_if_idle on a dead connection.

* Avoid error messages on PayloadTooLarge.

* Add test for new API.

* json takes str, not bytes

* Default to no maximum request size for streaming handlers.

* Fix chunked mode crash.

* Header values should be strictly ASCII but both UTF-8 and Latin-1 exist. Use UTF-8B to
cope with all.

* Refactoring and cleanup.

* Unify response header processing of ASGI and asyncio modes.

* Avoid special handling of StreamingHTTPResponse.

* 35 % speedup in HTTP/1.1 response formatting (not so much overall effect).

* Duplicate set-cookie headers were being produced.

* Cleanup processed_headers some more.

* Linting

* Import ordering

* Response middleware ran by async request.respond().

* Need to check if transport is closing to avoid getting stuck in sending loops after peer has disconnected.

* Middleware and error handling refactoring.

* Linter

* Fix tracking of HTTP stage when writing to transport fails.

* Add clarifying comment

* Add a check for request body functions and a test for NotImplementedError.

* Linter and typing

* These must be tuples + hack mypy warnings away.

* New streaming test and minor fixes.

* Constant receive buffer size.

* 256 KiB send and receive buffers.

* Revert "256 KiB send and receive buffers."

This reverts commit abc1e3edb2.

* app.handle_exception already sends the response.

* Improved handling of errors during request.

* An odd hack to avoid an httpx limitation that causes test failures.

* Limit request header size to 8 KiB at most.

* Remove unnecessary use of format string.

* Cleanup tests

* Remove artifact

* Fix type checking

* Mark test for skipping

* Cleanup some edge cases

* Add ignore_body flag to safe methods

* Add unit tests for timeout logic

* Add unit tests for timeout logic

* Fix Mock usage in timeout test

* Change logging test to only logger in handler

* Windows py3.8 logging issue with current testing client

* Add test_header_size_exceeded

* Resolve merge conflicts

* Add request middleware to hard exception handling

* Add request middleware to hard exception handling

* Request middleware on exception handlers

* Linting

* Cleanup deprecations

Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
L. Kärkkäinen 2021-01-11 00:45:36 +02:00 committed by GitHub
parent 574a9c27a6
commit 7028eae083
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1372 additions and 1348 deletions

View File

@ -2,37 +2,38 @@ from sanic import Blueprint, Sanic
from sanic.response import file, json from sanic.response import file, json
app = Sanic(__name__) app = Sanic(__name__)
blueprint = Blueprint('name', url_prefix='/my_blueprint') blueprint = Blueprint("name", url_prefix="/my_blueprint")
blueprint2 = Blueprint('name2', url_prefix='/my_blueprint2') blueprint2 = Blueprint("name2", url_prefix="/my_blueprint2")
blueprint3 = Blueprint('name3', url_prefix='/my_blueprint3') blueprint3 = Blueprint("name3", url_prefix="/my_blueprint3")
@blueprint.route('/foo') @blueprint.route("/foo")
async def foo(request): async def foo(request):
return json({'msg': 'hi from blueprint'}) return json({"msg": "hi from blueprint"})
@blueprint2.route('/foo') @blueprint2.route("/foo")
async def foo2(request): async def foo2(request):
return json({'msg': 'hi from blueprint2'}) return json({"msg": "hi from blueprint2"})
@blueprint3.route('/foo') @blueprint3.route("/foo")
async def index(request): async def index(request):
return await file('websocket.html') return await file("websocket.html")
@app.websocket('/feed') @app.websocket("/feed")
async def foo3(request, ws): async def foo3(request, ws):
while True: while True:
data = 'hello!' data = "hello!"
print('Sending: ' + data) print("Sending: " + data)
await ws.send(data) await ws.send(data)
data = await ws.recv() data = await ws.recv()
print('Received: ' + data) print("Received: " + data)
app.blueprint(blueprint) app.blueprint(blueprint)
app.blueprint(blueprint2) app.blueprint(blueprint2)
app.blueprint(blueprint3) app.blueprint(blueprint3)
app.run(host="0.0.0.0", port=8000, debug=True) app.run(host="0.0.0.0", port=9999, debug=True)

View File

@ -48,14 +48,14 @@ async def handler_file_stream(request):
) )
@app.route("/stream", stream=True) @app.post("/stream", stream=True)
async def handler_stream(request): async def handler_stream(request):
while True: while True:
body = await request.stream.read() body = await request.stream.read()
if body is None: if body is None:
break break
body = body.decode("utf-8").replace("1", "A") body = body.decode("utf-8").replace("1", "A")
# await response.write(body) await response.write(body)
return response.stream(body) return response.stream(body)

View File

@ -4,24 +4,27 @@ import os
import re import re
from asyncio import CancelledError, Protocol, ensure_future, get_event_loop from asyncio import CancelledError, Protocol, ensure_future, get_event_loop
from asyncio.futures import Future
from collections import defaultdict, deque from collections import defaultdict, deque
from functools import partial from functools import partial
from inspect import isawaitable, signature from inspect import isawaitable, signature
from socket import socket from socket import socket
from ssl import Purpose, SSLContext, create_default_context from ssl import Purpose, SSLContext, create_default_context
from traceback import format_exc from traceback import format_exc
from typing import Any, Dict, Optional, Type, Union from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
from urllib.parse import urlencode, urlunparse from urllib.parse import urlencode, urlunparse
from sanic import reloader_helpers from sanic import reloader_helpers
from sanic.asgi import ASGIApp from sanic.asgi import ASGIApp
from sanic.blueprint_group import BlueprintGroup from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
from sanic.config import BASE_LOGO, Config from sanic.config import BASE_LOGO, Config
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.exceptions import SanicException, ServerError, URLBuildError from sanic.exceptions import SanicException, ServerError, URLBuildError
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler, ListenerType, MiddlewareType
from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse
from sanic.router import Router from sanic.router import Router
from sanic.server import ( from sanic.server import (
AsyncioServer, AsyncioServer,
@ -42,16 +45,16 @@ class Sanic:
def __init__( def __init__(
self, self,
name=None, name: str = None,
router=None, router: Router = None,
error_handler=None, error_handler: ErrorHandler = None,
load_env=True, load_env: bool = True,
request_class=None, request_class: Request = None,
strict_slashes=False, strict_slashes: bool = False,
log_config=None, log_config: Optional[Dict[str, Any]] = None,
configure_logging=True, configure_logging: bool = True,
register=None, register: Optional[bool] = None,
): ) -> None:
# Get name from previous stack frame # Get name from previous stack frame
if name is None: if name is None:
@ -59,7 +62,6 @@ class Sanic:
"Sanic instance cannot be unnamed. " "Sanic instance cannot be unnamed. "
"Please use Sanic(name='your_application_name') instead.", "Please use Sanic(name='your_application_name') instead.",
) )
# logging # logging
if configure_logging: if configure_logging:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
@ -70,22 +72,21 @@ class Sanic:
self.request_class = request_class self.request_class = request_class
self.error_handler = error_handler or ErrorHandler() self.error_handler = error_handler or ErrorHandler()
self.config = Config(load_env=load_env) self.config = Config(load_env=load_env)
self.request_middleware = deque() self.request_middleware: Iterable[MiddlewareType] = deque()
self.response_middleware = deque() self.response_middleware: Iterable[MiddlewareType] = deque()
self.blueprints = {} self.blueprints: Dict[str, Blueprint] = {}
self._blueprint_order = [] self._blueprint_order: List[Blueprint] = []
self.configure_logging = configure_logging self.configure_logging = configure_logging
self.debug = None self.debug = None
self.sock = None self.sock = None
self.strict_slashes = strict_slashes self.strict_slashes = strict_slashes
self.listeners = defaultdict(list) self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
self.is_stopping = False self.is_stopping = False
self.is_running = False self.is_running = False
self.is_request_stream = False
self.websocket_enabled = False self.websocket_enabled = False
self.websocket_tasks = set() self.websocket_tasks: Set[Future] = set()
self.named_request_middleware = {} self.named_request_middleware: Dict[str, MiddlewareType] = {}
self.named_response_middleware = {} self.named_response_middleware: Dict[str, MiddlewareType] = {}
# Register alternative method names # Register alternative method names
self.go_fast = self.run self.go_fast = self.run
@ -162,6 +163,7 @@ class Sanic:
stream=False, stream=False,
version=None, version=None,
name=None, name=None,
ignore_body=False,
): ):
"""Decorate a function to be registered as a route """Decorate a function to be registered as a route
@ -180,9 +182,6 @@ class Sanic:
if not uri.startswith("/"): if not uri.startswith("/"):
uri = "/" + uri uri = "/" + uri
if stream:
self.is_request_stream = True
if strict_slashes is None: if strict_slashes is None:
strict_slashes = self.strict_slashes strict_slashes = self.strict_slashes
@ -215,6 +214,7 @@ class Sanic:
strict_slashes=strict_slashes, strict_slashes=strict_slashes,
version=version, version=version,
name=name, name=name,
ignore_body=ignore_body,
) )
) )
return routes, handler return routes, handler
@ -223,7 +223,13 @@ class Sanic:
# Shorthand method decorators # Shorthand method decorators
def get( def get(
self, uri, host=None, strict_slashes=None, version=None, name=None self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
): ):
""" """
Add an API URL under the **GET** *HTTP* method Add an API URL under the **GET** *HTTP* method
@ -243,6 +249,7 @@ class Sanic:
strict_slashes=strict_slashes, strict_slashes=strict_slashes,
version=version, version=version,
name=name, name=name,
ignore_body=ignore_body,
) )
def post( def post(
@ -306,7 +313,13 @@ class Sanic:
) )
def head( def head(
self, uri, host=None, strict_slashes=None, version=None, name=None self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
): ):
return self.route( return self.route(
uri, uri,
@ -315,10 +328,17 @@ class Sanic:
strict_slashes=strict_slashes, strict_slashes=strict_slashes,
version=version, version=version,
name=name, name=name,
ignore_body=ignore_body,
) )
def options( def options(
self, uri, host=None, strict_slashes=None, version=None, name=None self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
): ):
""" """
Add an API URL under the **OPTIONS** *HTTP* method Add an API URL under the **OPTIONS** *HTTP* method
@ -338,6 +358,7 @@ class Sanic:
strict_slashes=strict_slashes, strict_slashes=strict_slashes,
version=version, version=version,
name=name, name=name,
ignore_body=ignore_body,
) )
def patch( def patch(
@ -371,7 +392,13 @@ class Sanic:
) )
def delete( def delete(
self, uri, host=None, strict_slashes=None, version=None, name=None self,
uri,
host=None,
strict_slashes=None,
version=None,
name=None,
ignore_body=True,
): ):
""" """
Add an API URL under the **DELETE** *HTTP* method Add an API URL under the **DELETE** *HTTP* method
@ -391,6 +418,7 @@ class Sanic:
strict_slashes=strict_slashes, strict_slashes=strict_slashes,
version=version, version=version,
name=name, name=name,
ignore_body=ignore_body,
) )
def add_route( def add_route(
@ -497,6 +525,7 @@ class Sanic:
websocket_handler.__name__ = ( websocket_handler.__name__ = (
"websocket_handler_" + handler.__name__ "websocket_handler_" + handler.__name__
) )
websocket_handler.is_websocket = True
routes.extend( routes.extend(
self.router.add( self.router.add(
uri=uri, uri=uri,
@ -861,7 +890,52 @@ class Sanic:
""" """
pass pass
async def handle_request(self, request, write_callback, stream_callback): async def handle_exception(self, request, exception):
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
# No middleware results
if not response:
try:
response = self.error_handler.response(request, exception)
if isawaitable(response):
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(request, e)
elif self.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
if response is not None:
try:
response = await request.respond(response)
except BaseException:
# Skip response middleware
request.stream.respond(response)
await response.send(end_stream=True)
raise
else:
response = request.stream.response
if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True)
else:
raise ServerError(
f"Invalid response type {response!r} (need HTTPResponse)"
)
async def handle_request(self, request):
"""Take a request from the HTTP Server and return a response object """Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so to be sent back The HTTP Server only expects a response object, so
exception handling must be done here exception handling must be done here
@ -877,13 +951,27 @@ class Sanic:
# Define `response` var here to remove warnings about # Define `response` var here to remove warnings about
# allocation before assignment below. # allocation before assignment below.
response = None response = None
cancelled = False
name = None name = None
try: try:
# Fetch handler from router # Fetch handler from router
handler, args, kwargs, uri, name, endpoint = self.router.get( (
request handler,
) args,
kwargs,
uri,
name,
endpoint,
ignore_body,
) = self.router.get(request)
request.name = name
if request.stream.request_body and not ignore_body:
if self.router.is_stream_handler(request):
# Streaming handler: lift the size limit
request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await request.receive_body()
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
@ -912,72 +1000,31 @@ class Sanic:
response = handler(request, *args, **kwargs) response = handler(request, *args, **kwargs)
if isawaitable(response): if isawaitable(response):
response = await response response = await response
if response:
response = await request.respond(response)
else:
response = request.stream.response
# Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True)
else:
try:
# Fastest method for checking if the property exists
handler.is_websocket
except AttributeError:
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
except CancelledError: except CancelledError:
# If response handler times out, the server handles the error raise
# and cancels the handle_request job.
# In this case, the transport is already closed and we cannot
# issue a response.
response = None
cancelled = True
except Exception as e: except Exception as e:
# -------------------------------------------- # # -------------------------------------------- #
# Response Generation Failed # Response Generation Failed
# -------------------------------------------- # # -------------------------------------------- #
await self.handle_exception(request, e)
try:
response = self.error_handler.response(request, e)
if isawaitable(response):
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(
request=request, exception=e
)
elif self.debug:
response = HTTPResponse(
f"Error while "
f"handling error: {e}\nStack: {format_exc()}",
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
finally:
# -------------------------------------------- #
# Response Middleware
# -------------------------------------------- #
# Don't run response middleware if response is None
if response is not None:
try:
response = await self._run_response_middleware(
request, response, request_name=name
)
except CancelledError:
# Response middleware can timeout too, as above.
response = None
cancelled = True
except BaseException:
error_logger.exception(
"Exception occurred in one of response "
"middleware handlers"
)
if cancelled:
raise CancelledError()
# pass the response to the correct callback
if write_callback is None or isinstance(
response, StreamingHTTPResponse
):
if stream_callback:
await stream_callback(response)
else:
# Should only end here IF it is an ASGI websocket.
# TODO:
# - Add exception handling
pass
else:
write_callback(response)
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# Testing # Testing
@ -1213,7 +1260,12 @@ class Sanic:
request_name, deque() request_name, deque()
) )
applicable_middleware = self.request_middleware + named_middleware applicable_middleware = self.request_middleware + named_middleware
if applicable_middleware:
# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
for middleware in applicable_middleware: for middleware in applicable_middleware:
response = middleware(request) response = middleware(request)
if isawaitable(response): if isawaitable(response):
@ -1236,6 +1288,8 @@ class Sanic:
_response = await _response _response = await _response
if _response: if _response:
response = _response response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break break
return response return response

View File

@ -2,27 +2,15 @@ import asyncio
import warnings import warnings
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
Any,
Awaitable,
Callable,
Dict,
List,
MutableMapping,
Optional,
Tuple,
Union,
)
from urllib.parse import quote from urllib.parse import quote
import sanic.app # noqa import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.server import ConnInfo
from sanic.server import ConnInfo, StreamBuffer
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
@ -84,7 +72,7 @@ class MockTransport:
def get_extra_info(self, info: str) -> Union[str, bool, None]: def get_extra_info(self, info: str) -> Union[str, bool, None]:
if info == "peername": if info == "peername":
return self.scope.get("server") return self.scope.get("client")
elif info == "sslcontext": elif info == "sslcontext":
return self.scope.get("scheme") in ["https", "wss"] return self.scope.get("scheme") in ["https", "wss"]
return None return None
@ -151,7 +139,7 @@ class Lifespan:
response = handler( response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
) )
if isawaitable(response): if response and isawaitable(response):
await response await response
async def shutdown(self) -> None: async def shutdown(self) -> None:
@ -171,7 +159,7 @@ class Lifespan:
response = handler( response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
) )
if isawaitable(response): if response and isawaitable(response):
await response await response
async def __call__( async def __call__(
@ -192,7 +180,6 @@ class ASGIApp:
sanic_app: "sanic.app.Sanic" sanic_app: "sanic.app.Sanic"
request: Request request: Request
transport: MockTransport transport: MockTransport
do_stream: bool
lifespan: Lifespan lifespan: Lifespan
ws: Optional[WebSocketConnection] ws: Optional[WebSocketConnection]
@ -215,9 +202,6 @@ class ASGIApp:
for key, value in scope.get("headers", []) for key, value in scope.get("headers", [])
] ]
) )
instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance) instance.lifespan = Lifespan(instance)
if scope["type"] == "lifespan": if scope["type"] == "lifespan":
@ -244,9 +228,7 @@ class ASGIApp:
) )
await instance.ws.accept() await instance.ws.accept()
else: else:
pass raise ServerError("Received unknown ASGI scope")
# TODO:
# - close connection
request_class = sanic_app.request_class or Request request_class = sanic_app.request_class or Request
instance.request = request_class( instance.request = request_class(
@ -257,161 +239,57 @@ class ASGIApp:
instance.transport, instance.transport,
sanic_app, sanic_app,
) )
instance.request.stream = instance
instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport) instance.request.conn_info = ConnInfo(instance.transport)
if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler(
instance.request
)
if is_stream_handler:
instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
)
instance.do_stream = True
return instance return instance
async def read_body(self) -> bytes: async def read(self) -> Optional[bytes]:
"""
Read and return the entire body from an incoming ASGI message.
"""
body = b""
more_body = True
while more_body:
message = await self.transport.receive()
body += message.get("body", b"")
more_body = message.get("more_body", False)
return body
async def stream_body(self) -> None:
""" """
Read and stream the body in chunks from an incoming ASGI message. Read and stream the body in chunks from an incoming ASGI message.
""" """
more_body = True
while more_body:
message = await self.transport.receive() message = await self.transport.receive()
chunk = message.get("body", b"") if not message.get("more_body", False):
await self.request.stream.put(chunk) self.request_body = False
return None
return message.get("body", b"")
more_body = message.get("more_body", False) async def __aiter__(self):
while self.request_body:
data = await self.read()
if data:
yield data
await self.request.stream.put(None) def respond(self, response):
response.stream, self.response = self, response
return response
async def send(self, data, end_stream):
if self.response:
response, self.response = self.response, None
await self.transport.send(
{
"type": "http.response.start",
"status": response.status,
"headers": response.processed_headers,
}
)
response_body = getattr(response, "body", None)
if response_body:
data = response_body + data if data else response_body
await self.transport.send(
{
"type": "http.response.body",
"body": data.encode() if hasattr(data, "encode") else data,
"more_body": not end_stream,
}
)
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable
async def __call__(self) -> None: async def __call__(self) -> None:
""" """
Handle the incoming request. Handle the incoming request.
""" """
if not self.do_stream: await self.sanic_app.handle_request(self.request)
self.request.body = await self.read_body()
else:
self.sanic_app.loop.create_task(self.stream_body())
handler = self.sanic_app.handle_request
callback = None if self.ws else self.stream_callback
await handler(self.request, None, callback)
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable
async def stream_callback(
self, response: Union[HTTPResponse, StreamingHTTPResponse]
) -> None:
"""
Write the response.
"""
headers: List[Tuple[bytes, bytes]] = []
cookies: Dict[str, str] = {}
content_length: List[str] = []
try:
content_length = response.headers.popall("content-length", [])
cookies = {
v.key: v
for _, v in list(
filter(
lambda item: item[0].lower() == "set-cookie",
response.headers.items(),
)
)
}
headers += [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name.lower() not in ["set-cookie"]
]
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.request.url,
type(response),
)
exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(
self.request, exception
)
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name not in (b"Set-Cookie",)
]
response.asgi = True
is_streaming = isinstance(response, StreamingHTTPResponse)
if is_streaming and getattr(response, "chunked", False):
# disable sanic chunking, this is done at the ASGI-server level
setattr(response, "chunked", False)
# content-length header is removed to signal to the ASGI-server
# to use automatic-chunking if it supports it
elif len(content_length) > 0:
headers += [
(b"content-length", str(content_length[0]).encode("latin-1"))
]
elif not is_streaming:
headers += [
(
b"content-length",
str(len(getattr(response, "body", b""))).encode("latin-1"),
)
]
if "content-type" not in response.headers:
headers += [
(b"content-type", str(response.content_type).encode("latin-1"))
]
if response.cookies:
cookies.update(
{
v.key: v
for _, v in response.cookies.items()
if v.key not in cookies.keys()
}
)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for k, cookie in cookies.items()
]
await self.transport.send(
{
"type": "http.response.start",
"status": response.status,
"headers": headers,
}
)
if isinstance(response, StreamingHTTPResponse):
response.protocol = self.transport.get_protocol()
await response.stream()
await response.protocol.complete()
else:
await self.transport.send(
{
"type": "http.response.body",
"body": response.body,
"more_body": False,
}
)

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import os
import signal import signal
from sys import argv from sys import argv
@ -6,6 +7,9 @@ from sys import argv
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
OS_IS_WINDOWS = os.name == "nt"
class Header(CIMultiDict): class Header(CIMultiDict):
def get_all(self, key): def get_all(self, key):
return self.getall(key, default=[]) return self.getall(key, default=[])
@ -14,13 +18,13 @@ class Header(CIMultiDict):
use_trio = argv[0].endswith("hypercorn") and "trio" in argv use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio: if use_trio:
from trio import Path # type: ignore import trio # type: ignore
from trio import open_file as open_async # type: ignore
def stat_async(path): def stat_async(path):
return Path(path).stat() return trio.Path(path).stat()
open_async = trio.open_file
CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled])
else: else:
from aiofiles import open as aio_open # type: ignore from aiofiles import open as aio_open # type: ignore
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401 from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
@ -28,6 +32,8 @@ else:
async def open_async(file, mode="r", **kwargs): async def open_async(file, mode="r", **kwargs):
return aio_open(file, mode, **kwargs) return aio_open(file, mode, **kwargs)
CancelledErrors = tuple([asyncio.CancelledError])
def ctrlc_workaround_for_windows(app): def ctrlc_workaround_for_windows(app):
async def stay_active(app): async def stay_active(app):

View File

@ -23,6 +23,7 @@ BASE_LOGO = """
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_MAX_SIZE": 100000000, # 100 megabytes
"REQUEST_BUFFER_QUEUE_SIZE": 100, "REQUEST_BUFFER_QUEUE_SIZE": 100,
"REQUEST_BUFFER_SIZE": 65536, # 64 KiB
"REQUEST_TIMEOUT": 60, # 60 seconds "REQUEST_TIMEOUT": 60, # 60 seconds
"RESPONSE_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds
"KEEP_ALIVE": True, "KEEP_ALIVE": True,

View File

@ -206,7 +206,6 @@ class TextRenderer(BaseRenderer):
_, exc_value, __ = sys.exc_info() _, exc_value, __ = sys.exc_info()
exceptions = [] exceptions = []
# traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
lines = [ lines = [
f"{self.exception.__class__.__name__}: {self.exception} while " f"{self.exception.__class__.__name__}: {self.exception} while "
f"handling path {self.request.path}", f"handling path {self.request.path}",

View File

@ -1,4 +1,6 @@
from asyncio.events import AbstractEventLoop
from traceback import format_exc from traceback import format_exc
from typing import Any, Callable, Coroutine, Optional, TypeVar, Union
from sanic.errorpages import exception_response from sanic.errorpages import exception_response
from sanic.exceptions import ( from sanic.exceptions import (
@ -7,7 +9,23 @@ from sanic.exceptions import (
InvalidRangeType, InvalidRangeType,
) )
from sanic.log import logger from sanic.log import logger
from sanic.response import text from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, text
Sanic = TypeVar("Sanic")
MiddlewareResponse = Union[
Optional[HTTPResponse], Coroutine[Any, Any, Optional[HTTPResponse]]
]
RequestMiddlewareType = Callable[[Request], MiddlewareResponse]
ResponseMiddlewareType = Callable[
[Request, BaseHTTPResponse], MiddlewareResponse
]
MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType]
ListenerType = Callable[
[Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]]
]
class ErrorHandler: class ErrorHandler:

View File

@ -7,6 +7,7 @@ from sanic.helpers import STATUS_CODES
HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str
HeaderBytesIterable = Iterable[Tuple[bytes, bytes]]
Options = Dict[str, Union[int, str]] # key=value fields in various headers Options = Dict[str, Union[int, str]] # key=value fields in various headers
OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys
@ -175,26 +176,18 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]:
return host.lower(), int(port) if port is not None else None return host.lower(), int(port) if port is not None else None
def format_http1(headers: HeaderIterable) -> bytes: _HTTP1_STATUSLINES = [
"""Convert a headers iterable into HTTP/1 header format. b"HTTP/1.1 %d %b\r\n" % (status, STATUS_CODES.get(status, b"UNKNOWN"))
for status in range(1000)
- Outputs UTF-8 bytes where each header line ends with \\r\\n. ]
- Values are converted into strings if necessary.
"""
return "".join(f"{name}: {val}\r\n" for name, val in headers).encode()
def format_http1_response( def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
status: int, headers: HeaderIterable, body=b"" """Format a HTTP/1.1 response header."""
) -> bytes: # Note: benchmarks show that here bytes concat is faster than bytearray,
"""Format a full HTTP/1.1 response. # b"".join() or %-formatting. %timeit any changes you make.
ret = _HTTP1_STATUSLINES[status]
- If `body` is included, content-length must be specified in headers. for h in headers:
""" ret += b"%b: %b\r\n" % h
headerbytes = format_http1(headers) ret += b"\r\n"
return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % ( return ret
status,
STATUS_CODES.get(status, b"UNKNOWN"),
headerbytes,
body,
)

475
sanic/http.py Normal file
View File

@ -0,0 +1,475 @@
from asyncio import CancelledError, sleep
from enum import Enum
from sanic.compat import Header
from sanic.exceptions import (
HeaderExpectationFailed,
InvalidUsage,
PayloadTooLarge,
ServerError,
ServiceUnavailable,
)
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
from sanic.log import access_logger, logger
class Stage(Enum):
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent, body in progress
FAILED = 100 # Unrecoverable state (error while sending response)
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http:
__slots__ = [
"_send",
"_receive_more",
"recv_buffer",
"protocol",
"expecting_continue",
"stage",
"keep_alive",
"head_only",
"request",
"exception",
"url",
"request_body",
"request_bytes",
"request_bytes_left",
"request_max_size",
"response",
"response_func",
"response_bytes_left",
"upgrade_websocket",
]
def __init__(self, protocol):
self._send = protocol.send
self._receive_more = protocol.receive_more
self.recv_buffer = protocol.recv_buffer
self.protocol = protocol
self.expecting_continue = False
self.stage = Stage.IDLE
self.request_body = None
self.request_bytes = None
self.request_bytes_left = None
self.request_max_size = protocol.request_max_size
self.keep_alive = True
self.head_only = None
self.request = None
self.response = None
self.exception = None
self.url = None
self.upgrade_websocket = False
def __bool__(self):
"""Test if request handling is in progress"""
return self.stage in (Stage.HANDLER, Stage.RESPONSE)
async def http1(self):
"""HTTP 1.1 connection handler"""
while True: # As long as connection stays keep-alive
try:
# Receive and handle a request
self.stage = Stage.REQUEST
self.response_func = self.http1_response_header
await self.http1_request_header()
self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
# Handler finished, response should've been sent
if self.stage is Stage.HANDLER and not self.upgrade_websocket:
raise ServerError("Handler produced no response")
if self.stage is Stage.RESPONSE:
await self.response.send(end_stream=True)
except CancelledError:
# Write an appropriate response before exiting
e = self.exception or ServiceUnavailable("Cancelled")
self.exception = None
self.keep_alive = False
await self.error_response(e)
except Exception as e:
# Write an error response
await self.error_response(e)
# Try to consume any remaining request body
if self.request_body:
if self.response and 200 <= self.response.status < 300:
logger.error(f"{self.request} body not consumed.")
try:
async for _ in self:
pass
except PayloadTooLarge:
# We won't read the body and that may cause httpx and
# tests to fail. This little delay allows clients to push
# a small request into network buffers before we close the
# socket, so that they are then able to read the response.
await sleep(0.001)
self.keep_alive = False
# Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive:
break
# Wait for next request
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self):
"""Receive and parse request header into self.request."""
HEADER_MAX_SIZE = min(8192, self.request_max_size)
# Receive until full header is in buffer
buf = self.recv_buffer
pos = 0
while True:
pos = buf.find(b"\r\n\r\n", pos)
if pos != -1:
break
pos = max(0, len(buf) - 3)
if pos >= HEADER_MAX_SIZE:
break
await self._receive_more()
if pos >= HEADER_MAX_SIZE:
raise PayloadTooLarge("Request header exceeds the size limit")
# Parse header content
try:
raw_headers = buf[:pos].decode(errors="surrogateescape")
reqline, *raw_headers = raw_headers.split("\r\n")
method, self.url, protocol = reqline.split(" ")
if protocol == "HTTP/1.1":
self.keep_alive = True
elif protocol == "HTTP/1.0":
self.keep_alive = False
else:
raise Exception # Raise a Bad Request on try-except
self.head_only = method.upper() == "HEAD"
request_body = False
headers = []
for name, value in (h.split(":", 1) for h in raw_headers):
name, value = h = name.lower(), value.lstrip()
if name in ("content-length", "transfer-encoding"):
request_body = True
elif name == "connection":
self.keep_alive = value.lower() == "keep-alive"
headers.append(h)
except Exception:
raise InvalidUsage("Bad Request")
headers_instance = Header(headers)
self.upgrade_websocket = headers_instance.get("upgrade") == "websocket"
# Prepare a Request object
request = self.protocol.request_class(
url_bytes=self.url.encode(),
headers=headers_instance,
version=protocol[5:],
method=method,
transport=self.protocol.transport,
app=self.protocol.app,
)
# Prepare for request body
self.request_bytes_left = self.request_bytes = 0
if request_body:
headers = request.headers
expect = headers.get("expect")
if expect is not None:
if expect.lower() == "100-continue":
self.expecting_continue = True
else:
raise HeaderExpectationFailed(f"Unknown Expect: {expect}")
if headers.get("transfer-encoding") == "chunked":
self.request_body = "chunked"
pos -= 2 # One CRLF stays in buffer
else:
self.request_body = True
self.request_bytes_left = self.request_bytes = int(
headers["content-length"]
)
# Remove header and its trailing CRLF
del buf[: pos + 4]
self.stage = Stage.HANDLER
self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1
async def http1_response_header(self, data, end_stream):
res = self.response
# Compatibility with simple response body
if not data and getattr(res, "body", None):
data, end_stream = res.body, True
size = len(data)
headers = res.headers
status = res.status
if not isinstance(status, int) or status < 200:
raise RuntimeError(f"Invalid response status {status!r}")
if not has_message_body(status):
# Header-only response status
self.response_func = None
if (
data
or not end_stream
or "content-length" in headers
or "transfer-encoding" in headers
):
data, size, end_stream = b"", 0, True
headers.pop("content-length", None)
headers.pop("transfer-encoding", None)
logger.warning(
f"Message body set in response on {self.request.path}. "
f"A {status} response may only have headers, no body."
)
elif self.head_only and "content-length" in headers:
self.response_func = None
elif end_stream:
# Non-streaming response (all in one block)
headers["content-length"] = size
self.response_func = None
elif "content-length" in headers:
# Streaming response with size known in advance
self.response_bytes_left = int(headers["content-length"]) - size
self.response_func = self.http1_response_normal
else:
# Length not known, use chunked encoding
headers["transfer-encoding"] = "chunked"
data = b"%x\r\n%b\r\n" % (size, data) if size else None
self.response_func = self.http1_response_chunked
if self.head_only:
# Head request: don't send body
data = b""
self.response_func = self.head_response_ignored
headers["connection"] = "keep-alive" if self.keep_alive else "close"
ret = format_http1_response(status, res.processed_headers)
if data:
ret += data
# Send a 100-continue if expected and not Expectation Failed
if self.expecting_continue:
self.expecting_continue = False
if status != 417:
ret = HTTP_CONTINUE + ret
# Send response
if self.protocol.access_log:
self.log_response()
await self._send(ret)
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
def head_response_ignored(self, data, end_stream):
"""HEAD response: body data silently ignored."""
if end_stream:
self.response_func = None
self.stage = Stage.IDLE
async def http1_response_chunked(self, data, end_stream):
"""Format a part of response body in chunked encoding."""
# Chunked encoding
size = len(data)
if end_stream:
await self._send(
b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
if size
else b"0\r\n\r\n"
)
self.response_func = None
self.stage = Stage.IDLE
elif size:
await self._send(b"%x\r\n%b\r\n" % (size, data))
async def http1_response_normal(self, data: bytes, end_stream: bool):
"""Format / keep track of non-chunked response."""
bytes_left = self.response_bytes_left - len(data)
if bytes_left <= 0:
if bytes_left < 0:
raise ServerError("Response was bigger than content-length")
await self._send(data)
self.response_func = None
self.stage = Stage.IDLE
else:
if end_stream:
raise ServerError("Response was smaller than content-length")
await self._send(data)
self.response_bytes_left = bytes_left
async def error_response(self, exception):
# Disconnect after an error if in any other state than handler
if self.stage is not Stage.HANDLER:
self.keep_alive = False
# Request failure? Respond but then disconnect
if self.stage is Stage.REQUEST:
self.stage = Stage.HANDLER
# From request and handler states we can respond, otherwise be silent
if self.stage is Stage.HANDLER:
app = self.protocol.app
if self.request is None:
self.create_empty_request()
await app.handle_exception(self.request, exception)
def create_empty_request(self):
"""Current error handling code needs a request object that won't exist
if an error occurred during before a request was received. Create a
bogus response for error handling use."""
# FIXME: Avoid this by refactoring error handling and response code
self.request = self.protocol.request_class(
url_bytes=self.url.encode() if self.url else b"*",
headers=Header({}),
version="1.1",
method="NONE",
transport=self.protocol.transport,
app=self.protocol.app,
)
self.request.stream = self
def log_response(self):
"""
Helper method provided to enable the logging of responses in case if
the :attr:`HttpProtocol.access_log` is enabled.
:param response: Response generated for the current request
:type response: :class:`sanic.response.HTTPResponse` or
:class:`sanic.response.StreamingHTTPResponse`
:return: None
"""
req, res = self.request, self.response
extra = {
"status": getattr(res, "status", 0),
"byte": getattr(self, "response_bytes_left", -1),
"host": "UNKNOWN",
"request": "nil",
}
if req is not None:
if req.ip:
extra["host"] = f"{req.ip}:{req.port}"
extra["request"] = f"{req.method} {req.url}"
access_logger.info("", extra=extra)
# Request methods
async def __aiter__(self):
"""Async iterate over request body."""
while self.request_body:
data = await self.read()
if data:
yield data
async def read(self):
"""Read some bytes of request body."""
# Send a 100-continue if needed
if self.expecting_continue:
self.expecting_continue = False
await self._send(HTTP_CONTINUE)
# Receive request body chunk
buf = self.recv_buffer
if self.request_bytes_left == 0 and self.request_body == "chunked":
# Process a chunk header: \r\n<size>[;<chunk extensions>]\r\n
while True:
pos = buf.find(b"\r\n", 3)
if pos != -1:
break
if len(buf) > 64:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
await self._receive_more()
try:
size = int(buf[2:pos].split(b";", 1)[0].decode(), 16)
except Exception:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
del buf[: pos + 2]
if size <= 0:
self.request_body = None
if size < 0:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
return None
self.request_bytes_left = size
self.request_bytes += size
# Request size limit
if self.request_bytes > self.request_max_size:
self.keep_alive = False
raise PayloadTooLarge("Request body exceeds the size limit")
# End of request body?
if not self.request_bytes_left:
self.request_body = None
return
# At this point we are good to read/return up to request_bytes_left
if not buf:
await self._receive_more()
data = bytes(buf[: self.request_bytes_left])
size = len(data)
del buf[:size]
self.request_bytes_left -= size
return data
# Response methods
def respond(self, response):
"""Initiate new streaming response.
Nothing is sent until the first send() call on the returned object, and
calling this function multiple times will just alter the response to be
given."""
if self.stage is not Stage.HANDLER:
self.stage = Stage.FAILED
raise RuntimeError("Response already started")
self.response, response.stream = response, self
return response
@property
def send(self):
return self.response_func

View File

@ -1,4 +1,3 @@
import asyncio
import email.utils import email.utils
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
@ -8,6 +7,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url # type: ignore from httptools import parse_url # type: ignore
from sanic.compat import CancelledErrors
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
from sanic.headers import ( from sanic.headers import (
parse_content_header, parse_content_header,
@ -16,6 +16,7 @@ from sanic.headers import (
parse_xforwarded, parse_xforwarded,
) )
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.response import BaseHTTPResponse, HTTPResponse
try: try:
@ -24,7 +25,6 @@ except ImportError:
from json import loads as json_loads # type: ignore from json import loads as json_loads # type: ignore
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
EXPECT_HEADER = "EXPECT"
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 # HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
# > If the media type remains unknown, the recipient SHOULD treat it # > If the media type remains unknown, the recipient SHOULD treat it
@ -45,35 +45,6 @@ class RequestParameters(dict):
return super().get(name, default) return super().get(name, default)
class StreamBuffer:
def __init__(self, buffer_size=100):
self._queue = asyncio.Queue(buffer_size)
async def read(self):
""" Stop reading when gets None """
payload = await self._queue.get()
self._queue.task_done()
return payload
async def __aiter__(self):
"""Support `async for data in request.stream`"""
while True:
data = await self.read()
if not data:
break
yield data
async def put(self, payload):
await self._queue.put(payload)
def is_full(self):
return self._queue.full()
@property
def buffer_size(self):
return self._queue.maxsize
class Request: class Request:
"""Properties of an HTTP request such as URL, headers, etc.""" """Properties of an HTTP request such as URL, headers, etc."""
@ -92,6 +63,7 @@ class Request:
"endpoint", "endpoint",
"headers", "headers",
"method", "method",
"name",
"parsed_args", "parsed_args",
"parsed_not_grouped_args", "parsed_not_grouped_args",
"parsed_files", "parsed_files",
@ -99,6 +71,7 @@ class Request:
"parsed_json", "parsed_json",
"parsed_forwarded", "parsed_forwarded",
"raw_url", "raw_url",
"request_middleware_started",
"stream", "stream",
"transport", "transport",
"uri_template", "uri_template",
@ -117,9 +90,10 @@ class Request:
self.transport = transport self.transport = transport
# Init but do not inhale # Init but do not inhale
self.body_init() self.body = b""
self.conn_info = None self.conn_info = None
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.name = None
self.parsed_forwarded = None self.parsed_forwarded = None
self.parsed_json = None self.parsed_json = None
self.parsed_form = None self.parsed_form = None
@ -127,6 +101,7 @@ class Request:
self.parsed_args = defaultdict(RequestParameters) self.parsed_args = defaultdict(RequestParameters)
self.parsed_not_grouped_args = defaultdict(list) self.parsed_not_grouped_args = defaultdict(list)
self.uri_template = None self.uri_template = None
self.request_middleware_started = False
self._cookies = None self._cookies = None
self.stream = None self.stream = None
self.endpoint = None self.endpoint = None
@ -135,35 +110,42 @@ class Request:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
return f"<{class_name}: {self.method} {self.path}>" return f"<{class_name}: {self.method} {self.path}>"
def body_init(self): async def respond(
""".. deprecated:: 20.3 self, response=None, *, status=200, headers=None, content_type=None
To be removed in 21.3""" ):
self.body = [] # This logic of determining which response to use is subject to change
if response is None:
def body_push(self, data): response = self.stream.response or HTTPResponse(
""".. deprecated:: 20.3 status=status,
To be removed in 21.3""" headers=headers,
self.body.append(data) content_type=content_type,
)
def body_finish(self): # Connect the response
""".. deprecated:: 20.3 if isinstance(response, BaseHTTPResponse):
To be removed in 21.3""" response = self.stream.respond(response)
self.body = b"".join(self.body) # Run response middleware
try:
response = await self.app._run_response_middleware(
self, response, request_name=self.name
)
except CancelledErrors:
raise
except Exception:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
)
return response
async def receive_body(self): async def receive_body(self):
"""Receive request.body, if not already received. """Receive request.body, if not already received.
Streaming handlers may call this to receive the full body. Streaming handlers may call this to receive the full body. Sanic calls
this function before running any handlers of non-streaming routes.
This is added as a compatibility shim in Sanic 20.3 because future Custom request classes can override this for custom handling of both
versions of Sanic will make all requests streaming and will use this streaming and non-streaming routes.
function instead of the non-async body_init/push/finish functions.
Please make an issue if your code depends on the old functionality and
cannot be upgraded to the new API.
""" """
if not self.stream: if not self.body:
return
self.body = b"".join([data async for data in self.stream]) self.body = b"".join([data async for data in self.stream])
@property @property

View File

@ -2,10 +2,10 @@ from functools import partial
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from urllib.parse import quote_plus from urllib.parse import quote_plus
from warnings import warn
from sanic.compat import Header, open_async from sanic.compat import Header, open_async
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
from sanic.headers import format_http1, format_http1_response
from sanic.helpers import has_message_body, remove_entity_headers from sanic.helpers import has_message_body, remove_entity_headers
@ -28,50 +28,51 @@ class BaseHTTPResponse:
return b"" return b""
return data.encode() if hasattr(data, "encode") else data return data.encode() if hasattr(data, "encode") else data
def _parse_headers(self):
return format_http1(self.headers.items())
@property @property
def cookies(self): def cookies(self):
if self._cookies is None: if self._cookies is None:
self._cookies = CookieJar(self.headers) self._cookies = CookieJar(self.headers)
return self._cookies return self._cookies
def get_headers( @property
self, def processed_headers(self):
version="1.1", """Obtain a list of header tuples encoded in bytes for sending.
keep_alive=False,
keep_alive_timeout=None,
body=b"",
):
""".. deprecated:: 20.3:
This function is not public API and will be removed in 21.3."""
# self.headers get priority over content_type Add and remove headers based on status and content_type.
if self.content_type and "Content-Type" not in self.headers: """
self.headers["Content-Type"] = self.content_type # TODO: Make a blacklist set of header names and then filter with that
if self.status in (304, 412): # Not Modified, Precondition Failed
if keep_alive:
self.headers["Connection"] = "keep-alive"
if keep_alive_timeout is not None:
self.headers["Keep-Alive"] = keep_alive_timeout
else:
self.headers["Connection"] = "close"
if self.status in (304, 412):
self.headers = remove_entity_headers(self.headers) self.headers = remove_entity_headers(self.headers)
if has_message_body(self.status):
self.headers.setdefault("content-type", self.content_type)
# Encode headers into bytes
return (
(name.encode("ascii"), f"{value}".encode(errors="surrogateescape"))
for name, value in self.headers.items()
)
return format_http1_response(self.status, self.headers.items(), body) async def send(self, data=None, end_stream=None):
"""Send any pending response headers and the given data as body.
:param data: str or bytes to be written
:end_stream: whether to close the stream after this block
"""
if data is None and end_stream is None:
end_stream = True
if end_stream and not data and self.stream.send is None:
return
data = data.encode() if hasattr(data, "encode") else data or b""
await self.stream.send(data, end_stream=end_stream)
class StreamingHTTPResponse(BaseHTTPResponse): class StreamingHTTPResponse(BaseHTTPResponse):
"""Old style streaming response. Use `request.respond()` instead of this in
new code to avoid the callback."""
__slots__ = ( __slots__ = (
"protocol",
"streaming_fn", "streaming_fn",
"status", "status",
"content_type", "content_type",
"headers", "headers",
"chunked",
"_cookies", "_cookies",
) )
@ -81,63 +82,34 @@ class StreamingHTTPResponse(BaseHTTPResponse):
status=200, status=200,
headers=None, headers=None,
content_type="text/plain; charset=utf-8", content_type="text/plain; charset=utf-8",
chunked=True, chunked="deprecated",
): ):
if chunked != "deprecated":
warn(
"The chunked argument has been deprecated and will be "
"removed in v21.6"
)
super().__init__() super().__init__()
self.content_type = content_type self.content_type = content_type
self.streaming_fn = streaming_fn self.streaming_fn = streaming_fn
self.status = status self.status = status
self.headers = Header(headers or {}) self.headers = Header(headers or {})
self.chunked = chunked
self._cookies = None self._cookies = None
self.protocol = None
async def write(self, data): async def write(self, data):
"""Writes a chunk of data to the streaming response. """Writes a chunk of data to the streaming response.
:param data: str or bytes-ish data to be written. :param data: str or bytes-ish data to be written.
""" """
data = self._encode_body(data) await super().send(self._encode_body(data))
# `chunked` will always be False in ASGI-mode, even if the underlying async def send(self, *args, **kwargs):
# ASGI Transport implements Chunked transport. That does it itself. if self.streaming_fn is not None:
if self.chunked:
await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
else:
await self.protocol.push_data(data)
await self.protocol.drain()
async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
"""Streams headers, runs the `streaming_fn` callback that writes
content to the response body, then finalizes the response body.
"""
if version != "1.1":
self.chunked = False
if not getattr(self, "asgi", False):
headers = self.get_headers(
version,
keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout,
)
await self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self) await self.streaming_fn(self)
if self.chunked: self.streaming_fn = None
await self.protocol.push_data(b"0\r\n\r\n") await super().send(*args, **kwargs)
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.
def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
if self.chunked and version == "1.1":
self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop("Content-Length", None)
return super().get_headers(version, keep_alive, keep_alive_timeout)
class HTTPResponse(BaseHTTPResponse): class HTTPResponse(BaseHTTPResponse):
@ -158,22 +130,6 @@ class HTTPResponse(BaseHTTPResponse):
self.headers = Header(headers or {}) self.headers = Header(headers or {})
self._cookies = None self._cookies = None
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
body = b""
if has_message_body(self.status):
body = self.body
self.headers["Content-Length"] = self.headers.get(
"Content-Length", len(self.body)
)
return self.get_headers(version, keep_alive, keep_alive_timeout, body)
@property
def cookies(self):
if self._cookies is None:
self._cookies = CookieJar(self.headers)
return self._cookies
def empty(status=204, headers=None): def empty(status=204, headers=None):
""" """
@ -319,7 +275,7 @@ async def file_stream(
mime_type=None, mime_type=None,
headers=None, headers=None,
filename=None, filename=None,
chunked=True, chunked="deprecated",
_range=None, _range=None,
): ):
"""Return a streaming response object with file data. """Return a streaming response object with file data.
@ -329,9 +285,15 @@ async def file_stream(
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
:param filename: Override filename. :param filename: Override filename.
:param chunked: Enable or disable chunked transfer-encoding :param chunked: Deprecated
:param _range: :param _range:
""" """
if chunked != "deprecated":
warn(
"The chunked argument has been deprecated and will be "
"removed in v21.6"
)
headers = headers or {} headers = headers or {}
if filename: if filename:
headers.setdefault( headers.setdefault(
@ -370,7 +332,6 @@ async def file_stream(
status=status, status=status,
headers=headers, headers=headers,
content_type=mime_type, content_type=mime_type,
chunked=chunked,
) )
@ -379,7 +340,7 @@ def stream(
status=200, status=200,
headers=None, headers=None,
content_type="text/plain; charset=utf-8", content_type="text/plain; charset=utf-8",
chunked=True, chunked="deprecated",
): ):
"""Accepts an coroutine `streaming_fn` which can be used to """Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`. write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
@ -398,14 +359,19 @@ def stream(
writes content to that response. writes content to that response.
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
:param chunked: Enable or disable chunked transfer-encoding :param chunked: Deprecated
""" """
if chunked != "deprecated":
warn(
"The chunked argument has been deprecated and will be "
"removed in v21.6"
)
return StreamingHTTPResponse( return StreamingHTTPResponse(
streaming_fn, streaming_fn,
headers=headers, headers=headers,
content_type=content_type, content_type=content_type,
status=status, status=status,
chunked=chunked,
) )

View File

@ -20,6 +20,7 @@ Route = namedtuple(
"name", "name",
"uri", "uri",
"endpoint", "endpoint",
"ignore_body",
], ],
) )
Parameter = namedtuple("Parameter", ["name", "cast"]) Parameter = namedtuple("Parameter", ["name", "cast"])
@ -135,6 +136,7 @@ class Router:
handler, handler,
host=None, host=None,
strict_slashes=False, strict_slashes=False,
ignore_body=False,
version=None, version=None,
name=None, name=None,
): ):
@ -146,6 +148,7 @@ class Router:
:param handler: request handler function. :param handler: request handler function.
When executed, it should provide a response object. When executed, it should provide a response object.
:param strict_slashes: strict to trailing slash :param strict_slashes: strict to trailing slash
:param ignore_body: Handler should not read the body, if any
:param version: current version of the route or blueprint. See :param version: current version of the route or blueprint. See
docs for further details. docs for further details.
:return: Nothing :return: Nothing
@ -155,7 +158,9 @@ class Router:
version = re.escape(str(version).strip("/").lstrip("v")) version = re.escape(str(version).strip("/").lstrip("v"))
uri = "/".join([f"/v{version}", uri.lstrip("/")]) uri = "/".join([f"/v{version}", uri.lstrip("/")])
# add regular version # add regular version
routes.append(self._add(uri, methods, handler, host, name)) routes.append(
self._add(uri, methods, handler, host, name, ignore_body)
)
if strict_slashes: if strict_slashes:
return routes return routes
@ -187,14 +192,20 @@ class Router:
) )
# add version with trailing slash # add version with trailing slash
if slash_is_missing: if slash_is_missing:
routes.append(self._add(uri + "/", methods, handler, host, name)) routes.append(
self._add(uri + "/", methods, handler, host, name, ignore_body)
)
# add version without trailing slash # add version without trailing slash
elif without_slash_is_missing: elif without_slash_is_missing:
routes.append(self._add(uri[:-1], methods, handler, host, name)) routes.append(
self._add(uri[:-1], methods, handler, host, name, ignore_body)
)
return routes return routes
def _add(self, uri, methods, handler, host=None, name=None): def _add(
self, uri, methods, handler, host=None, name=None, ignore_body=False
):
"""Add a handler to the route list """Add a handler to the route list
:param uri: path to match :param uri: path to match
@ -326,6 +337,7 @@ class Router:
name=handler_name, name=handler_name,
uri=uri, uri=uri,
endpoint=endpoint, endpoint=endpoint,
ignore_body=ignore_body,
) )
self.routes_all[uri] = route self.routes_all[uri] = route
@ -465,7 +477,15 @@ class Router:
if hasattr(route_handler, "handlers"): if hasattr(route_handler, "handlers"):
route_handler = route_handler.handlers[method] route_handler = route_handler.handlers[method]
return route_handler, [], kwargs, route.uri, route.name, route.endpoint return (
route_handler,
[],
kwargs,
route.uri,
route.name,
route.endpoint,
route.ignore_body,
)
def is_stream_handler(self, request): def is_stream_handler(self, request):
"""Handler for request is stream or not. """Handler for request is stream or not.

View File

@ -5,33 +5,22 @@ import secrets
import socket import socket
import stat import stat
import sys import sys
import traceback
from collections import deque from asyncio import CancelledError
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from ipaddress import ip_address from ipaddress import ip_address
from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func from signal import signal as signal_func
from time import time from time import monotonic as current_time
from typing import Dict, Type, Union from typing import Dict, Type, Union
from httptools import HttpRequestParser # type: ignore from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from httptools.parser.errors import HttpParserError # type: ignore
from sanic.compat import Header, ctrlc_workaround_for_windows
from sanic.config import Config from sanic.config import Config
from sanic.exceptions import ( from sanic.exceptions import RequestTimeout, ServiceUnavailable
HeaderExpectationFailed, from sanic.http import Http, Stage
InvalidUsage, from sanic.log import logger
PayloadTooLarge, from sanic.request import Request
RequestTimeout,
ServerError,
ServiceUnavailable,
)
from sanic.log import access_logger, logger
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
from sanic.response import HTTPResponse
try: try:
@ -42,8 +31,6 @@ try:
except ImportError: except ImportError:
pass pass
OS_IS_WINDOWS = os.name == "nt"
class Signal: class Signal:
stopped = False stopped = False
@ -99,10 +86,7 @@ class HttpProtocol(asyncio.Protocol):
"signal", "signal",
"conn_info", "conn_info",
# request params # request params
"parser",
"request", "request",
"url",
"headers",
# request config # request config
"request_handler", "request_handler",
"request_timeout", "request_timeout",
@ -111,26 +95,21 @@ class HttpProtocol(asyncio.Protocol):
"request_max_size", "request_max_size",
"request_buffer_queue_size", "request_buffer_queue_size",
"request_class", "request_class",
"is_request_stream",
"error_handler", "error_handler",
# enable or disable access log purpose # enable or disable access log purpose
"access_log", "access_log",
# connection management # connection management
"_total_request_size",
"_request_timeout_handler",
"_response_timeout_handler",
"_keep_alive_timeout_handler",
"_last_request_time",
"_last_response_time",
"_is_stream_handler",
"_not_paused",
"_request_handler_task",
"_request_stream_task",
"_keep_alive",
"_header_fragment",
"state", "state",
"url",
"_handler_task",
"_can_write",
"_data_received",
"_time",
"_task",
"_http",
"_exception",
"recv_buffer",
"_unix", "_unix",
"_body_chunks",
) )
def __init__( def __init__(
@ -148,12 +127,10 @@ class HttpProtocol(asyncio.Protocol):
self.loop = loop self.loop = loop
deprecated_loop = self.loop if sys.version_info < (3, 7) else None deprecated_loop = self.loop if sys.version_info < (3, 7) else None
self.app = app self.app = app
self.url = None
self.transport = None self.transport = None
self.conn_info = None self.conn_info = None
self.request = None self.request = None
self.parser = None
self.url = None
self.headers = None
self.signal = signal self.signal = signal
self.access_log = self.app.config.ACCESS_LOG self.access_log = self.app.config.ACCESS_LOG
self.connections = connections if connections is not None else set() self.connections = connections if connections is not None else set()
@ -167,511 +144,99 @@ class HttpProtocol(asyncio.Protocol):
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request self.request_class = self.app.request_class or Request
self.is_request_stream = self.app.is_request_stream
self._is_stream_handler = False
self._not_paused = asyncio.Event(loop=deprecated_loop)
self._total_request_size = 0
self._request_timeout_handler = None
self._response_timeout_handler = None
self._keep_alive_timeout_handler = None
self._last_request_time = None
self._last_response_time = None
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = self.app.config.KEEP_ALIVE
self._header_fragment = b""
self.state = state if state else {} self.state = state if state else {}
if "requests_count" not in self.state: if "requests_count" not in self.state:
self.state["requests_count"] = 0 self.state["requests_count"] = 0
self._data_received = asyncio.Event(loop=deprecated_loop)
self._can_write = asyncio.Event(loop=deprecated_loop)
self._can_write.set()
self._exception = None
self._unix = unix self._unix = unix
self._not_paused.set()
self._body_chunks = deque()
@property def _setup_connection(self):
def keep_alive(self): self._http = Http(self)
self._time = current_time()
self.check_timeouts()
async def connection_task(self):
"""Run a HTTP connection.
Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there.
""" """
Check if the connection needs to be kept alive based on the params try:
attached to the `_keep_alive` attribute, :attr:`Signal.stopped` self._setup_connection()
and :func:`HttpProtocol.parser.should_keep_alive` await self._http.http1()
except CancelledError:
:return: ``True`` if connection is to be kept alive ``False`` else pass
""" except Exception:
return ( logger.exception("protocol.connection_task uncaught")
self._keep_alive finally:
and not self.signal.stopped if self.app.debug and self._http:
and self.parser.should_keep_alive() ip = self.transport.get_extra_info("peername")
logger.error(
"Connection lost before response written"
f" @ {ip} {self._http.request}"
) )
self._http = None
self._task = None
try:
self.close()
except BaseException:
logger.exception("Closing failed")
# -------------------------------------------- # async def receive_more(self):
# Connection """Wait until more data is received into self._buffer."""
# -------------------------------------------- # self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def connection_made(self, transport): def check_timeouts(self):
self.connections.add(self) """Runs itself periodically to enforce any expired timeouts."""
self._request_timeout_handler = self.loop.call_later( try:
self.request_timeout, self.request_timeout_callback if not self._task:
) return
self.transport = transport duration = current_time() - self._time
self.conn_info = ConnInfo(transport, unix=self._unix) stage = self._http.stage
self._last_request_time = time() if stage is Stage.IDLE and duration > self.keep_alive_timeout:
def connection_lost(self, exc):
self.connections.discard(self)
if self._request_handler_task:
self._request_handler_task.cancel()
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
def pause_writing(self):
self._not_paused.clear()
def resume_writing(self):
self._not_paused.set()
def request_timeout_callback(self):
# See the docstring in the RequestTimeout exception, to see
# exactly what this timeout is checking for.
# Check if elapsed time since request initiated exceeds our
# configured maximum request timeout value
time_elapsed = time() - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = self.loop.call_later(
time_left, self.request_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self):
# Check if elapsed time since response was initiated exceeds our
# configured maximum request timeout value
time_elapsed = time() - self._last_request_time
if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = self.loop.call_later(
time_left, self.response_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self):
"""
Check if elapsed time since last response exceeds our configured
maximum keep alive timeout value and if so, close the transport
pipe and let the response writer handle the error.
:return: None
"""
time_elapsed = time() - self._last_response_time
if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = self.loop.call_later(
time_left, self.keep_alive_timeout_callback
)
else:
logger.debug("KeepAlive Timeout. Closing connection.") logger.debug("KeepAlive Timeout. Closing connection.")
self.transport.close() elif stage is Stage.REQUEST and duration > self.request_timeout:
self.transport = None self._http.exception = RequestTimeout("Request Timeout")
elif (
# -------------------------------------------- # stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
# Parsing and duration > self.response_timeout
# -------------------------------------------- #
def data_received(self, data):
# Check for the request itself getting too large and exceeding
# memory limits
self._total_request_size += len(data)
if self._total_request_size > self.request_max_size:
self.write_error(PayloadTooLarge("Payload Too Large"))
# Create parser if this is the first time we're receiving data
if self.parser is None:
assert self.request is None
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state["requests_count"] = self.state["requests_count"] + 1
# Parse request chunk or close connection
try:
self.parser.feed_data(data)
except HttpParserError:
message = "Bad Request"
if self.app.debug:
message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message))
def on_url(self, url):
if not self.url:
self.url = url
else:
self.url += url
def on_header(self, name, value):
self._header_fragment += name
if value is not None:
if (
self._header_fragment == b"Content-Length"
and int(value) > self.request_max_size
): ):
self.write_error(PayloadTooLarge("Payload Too Large")) self._http.exception = ServiceUnavailable("Response Timeout")
try:
value = value.decode()
except UnicodeDecodeError:
value = value.decode("latin_1")
self.headers.append(
(self._header_fragment.decode().casefold(), value)
)
self._header_fragment = b""
def on_headers_complete(self):
self.request = self.request_class(
url_bytes=self.url,
headers=Header(self.headers),
version=self.parser.get_http_version(),
method=self.parser.get_method().decode(),
transport=self.transport,
app=self.app,
)
self.request.conn_info = self.conn_info
# Remove any existing KeepAlive handler here,
# It will be recreated if required on the new request.
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
self._keep_alive_timeout_handler = None
if self.request.headers.get(EXPECT_HEADER):
self.expect_handler()
if self.is_request_stream:
self._is_stream_handler = self.app.router.is_stream_handler(
self.request
)
if self._is_stream_handler:
self.request.stream = StreamBuffer(
self.request_buffer_queue_size
)
self.execute_request_handler()
def expect_handler(self):
"""
Handler for Expect Header.
"""
expect = self.request.headers.get(EXPECT_HEADER)
if self.request.version == "1.1":
if expect.lower() == "100-continue":
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else: else:
self.write_error( interval = (
HeaderExpectationFailed(f"Unknown Expect: {expect}") min(
self.keep_alive_timeout,
self.request_timeout,
self.response_timeout,
) )
/ 2
def on_body(self, body):
if self.is_request_stream and self._is_stream_handler:
# body chunks can be put into asyncio.Queue out of order if
# multiple tasks put concurrently and the queue is full in python
# 3.7. so we should not create more than one task putting into the
# queue simultaneously.
self._body_chunks.append(body)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
) )
else: self.loop.call_later(max(0.1, interval), self.check_timeouts)
self.request.body_push(body)
async def body_append(self, body):
if (
self.request is None
or self._request_stream_task is None
or self._request_stream_task.cancelled()
):
return return
self._task.cancel()
except Exception:
logger.exception("protocol.check_timeouts")
if self.request.stream.is_full(): async def send(self, data):
self.transport.pause_reading() """Writes data with backpressure control."""
await self.request.stream.put(body) await self._can_write.wait()
self.transport.resume_reading() if self.transport.is_closing():
else: raise CancelledError
await self.request.stream.put(body)
async def stream_append(self):
while self._body_chunks:
body = self._body_chunks.popleft()
if self.request:
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
def on_message_complete(self):
# Entire request (headers and whole body) is received.
# We can cancel and remove the request timeout handler now.
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
self._request_timeout_handler = None
if self.is_request_stream and self._is_stream_handler:
self._body_chunks.append(None)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
return
self.request.body_finish()
self.execute_request_handler()
def execute_request_handler(self):
"""
Invoke the request handler defined by the
:func:`sanic.app.Sanic.handle_request` method
:return: None
"""
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = time()
self._request_handler_task = self.loop.create_task(
self.request_handler(
self.request, self.write_response, self.stream_response
)
)
# -------------------------------------------- #
# Responding
# -------------------------------------------- #
def log_response(self, response):
"""
Helper method provided to enable the logging of responses in case if
the :attr:`HttpProtocol.access_log` is enabled.
:param response: Response generated for the current request
:type response: :class:`sanic.response.HTTPResponse` or
:class:`sanic.response.StreamingHTTPResponse`
:return: None
"""
if self.access_log:
extra = {"status": getattr(response, "status", 0)}
if isinstance(response, HTTPResponse):
extra["byte"] = len(response.body)
else:
extra["byte"] = -1
extra["host"] = "UNKNOWN"
if self.request is not None:
if self.request.ip:
extra["host"] = f"{self.request.ip}:{self.request.port}"
extra["request"] = f"{self.request.method} {self.request.url}"
else:
extra["request"] = "nil"
access_logger.info("", extra=extra)
def write_response(self, response):
"""
Writes response content synchronously to the transport.
"""
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try:
keep_alive = self.keep_alive
self.transport.write(
response.output(
self.request.version, keep_alive, self.keep_alive_timeout
)
)
self.log_response(response)
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
except Exception as e:
self.bail_out(f"Writing response failed, connection closed {e!r}")
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time()
self.cleanup()
async def drain(self):
await self._not_paused.wait()
async def push_data(self, data):
self.transport.write(data) self.transport.write(data)
self._time = current_time()
async def stream_response(self, response):
"""
Streams a response to the client asynchronously. Attaches
the transport to the response so the response consumer can
write to the response as needed.
"""
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try:
keep_alive = self.keep_alive
response.protocol = self
await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout
)
self.log_response(response)
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
except Exception as e:
self.bail_out(f"Writing response failed, connection closed {e!r}")
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time()
self.cleanup()
def write_error(self, exception):
# An error _is_ a response.
# Don't throw a response timeout, when a response _is_ given.
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
response = None
try:
response = self.error_handler.response(self.request, exception)
version = self.request.version if self.request else "1.1"
self.transport.write(response.output(version))
except RuntimeError:
if self.app.debug:
logger.error(
"Connection lost before error written @ %s",
self.request.ip if self.request else "Unknown",
)
except Exception as e:
self.bail_out(
f"Writing error failed, connection closed {e!r}",
from_error=True,
)
finally:
if self.parser and (
self.keep_alive or getattr(response, "status", 0) == 408
):
self.log_response(response)
try:
self.transport.close()
except AttributeError:
logger.debug("Connection lost before server could close it.")
def bail_out(self, message, from_error=False):
"""
In case if the transport pipes are closed and the sanic app encounters
an error while writing data to the transport pipe, we log the error
with proper details.
:param message: Error message to display
:param from_error: If the bail out was invoked while handling an
exception scenario.
:type message: str
:type from_error: bool
:return: None
"""
if from_error or self.transport is None or self.transport.is_closing():
logger.error(
"Transport closed @ %s and exception "
"experienced during error handling",
(
self.transport.get_extra_info("peername")
if self.transport is not None
else "N/A"
),
)
logger.debug("Exception:", exc_info=True)
else:
self.write_error(ServerError(message))
logger.error(message)
def cleanup(self):
"""This is called when KeepAlive feature is used,
it resets the connection in order for it to be able
to handle receiving another request on the same connection."""
self.parser = None
self.request = None
self.url = None
self.headers = None
self._request_handler_task = None
self._request_stream_task = None
self._total_request_size = 0
self._is_stream_handler = False
def close_if_idle(self): def close_if_idle(self):
"""Close the connection if a request is not being sent or received """Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open :return: boolean - True if closed, false if staying open
""" """
if not self.parser and self.transport is not None: if self._http is None or self._http.stage is Stage.IDLE:
self.transport.close() self.close()
return True return True
return False return False
@ -679,10 +244,57 @@ class HttpProtocol(asyncio.Protocol):
""" """
Force close the connection. Force close the connection.
""" """
if self.transport is not None: # Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close() self.transport.close()
self.transport = None self.transport = None
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self._task = self.loop.create_task(self.connection_task())
self.recv_buffer = bytearray()
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except Exception:
logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
if len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE:
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
except Exception:
logger.exception("protocol.data_received")
def trigger_events(events, loop): def trigger_events(events, loop):
"""Trigger event callbacks (functions or async) """Trigger event callbacks (functions or async)

View File

@ -47,11 +47,21 @@ class SanicTestClient:
async with self.get_new_session() as session: async with self.get_new_session() as session:
try: try:
if method == "request":
args = [url] + list(args)
url = kwargs.pop("http_method", "GET").upper()
response = await getattr(session, method.lower())( response = await getattr(session, method.lower())(
url, *args, **kwargs url, *args, **kwargs
) )
except NameError: except httpx.HTTPError as e:
raise Exception(response.status_code) if hasattr(e, "response"):
response = e.response
else:
logger.error(
f"{method.upper()} {url} received no response!",
exc_info=True,
)
return None
response.body = await response.aread() response.body = await response.aread()
response.status = response.status_code response.status = response.status_code
@ -85,7 +95,6 @@ class SanicTestClient:
): ):
results = [None, None] results = [None, None]
exceptions = [] exceptions = []
if gather_request: if gather_request:
def _collect_request(request): def _collect_request(request):
@ -161,6 +170,9 @@ class SanicTestClient:
except BaseException: # noqa except BaseException: # noqa
raise ValueError(f"Request object expected, got ({results})") raise ValueError(f"Request object expected, got ({results})")
def request(self, *args, **kwargs):
return self._sanic_endpoint_test("request", *args, **kwargs)
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
return self._sanic_endpoint_test("get", *args, **kwargs) return self._sanic_endpoint_test("get", *args, **kwargs)

View File

@ -118,7 +118,7 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch): def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs): def mockreturn(*args, **kwargs):
return None, [], {}, "", "", None return None, [], {}, "", "", None, False
# Not sure how to make app.router.get() return None, so use mock here. # Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn) monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -8,7 +8,7 @@ def test_bad_request_response(app):
async def _request(sanic, loop): async def _request(sanic, loop):
connect = asyncio.open_connection("127.0.0.1", 42101) connect = asyncio.open_connection("127.0.0.1", 42101)
reader, writer = await connect reader, writer = await connect
writer.write(b"not http") writer.write(b"not http\r\n\r\n")
while True: while True:
line = await reader.readline() line = await reader.readline()
if not line: if not line:

View File

@ -71,7 +71,6 @@ def test_bp(app):
app.blueprint(bp) app.blueprint(bp)
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert app.is_request_stream is False
assert response.text == "Hello" assert response.text == "Hello"
@ -505,48 +504,38 @@ def test_bp_shorthand(app):
@blueprint.get("/get") @blueprint.get("/get")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.put("/put") @blueprint.put("/put")
def put_handler(request): def put_handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.post("/post") @blueprint.post("/post")
def post_handler(request): def post_handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.head("/head") @blueprint.head("/head")
def head_handler(request): def head_handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.options("/options") @blueprint.options("/options")
def options_handler(request): def options_handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.patch("/patch") @blueprint.patch("/patch")
def patch_handler(request): def patch_handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.delete("/delete") @blueprint.delete("/delete")
def delete_handler(request): def delete_handler(request):
assert request.stream is None
return text("OK") return text("OK")
@blueprint.websocket("/ws/", strict_slashes=True) @blueprint.websocket("/ws/", strict_slashes=True)
async def websocket_handler(request, ws): async def websocket_handler(request, ws):
assert request.stream is None
ev.set() ev.set()
app.blueprint(blueprint) app.blueprint(blueprint)
assert app.is_request_stream is False
request, response = app.test_client.get("/get") request, response = app.test_client.get("/get")
assert response.text == "OK" assert response.text == "OK"

View File

@ -6,17 +6,14 @@ from sanic.response import json_dumps, text
class CustomRequest(Request): class CustomRequest(Request):
__slots__ = ("body_buffer",) """Alternative implementation for loading body (non-streaming handlers)"""
def body_init(self): async def receive_body(self):
self.body_buffer = BytesIO() buffer = BytesIO()
async for data in self.stream:
def body_push(self, data): buffer.write(data)
self.body_buffer.write(data) self.body = buffer.getvalue().upper()
buffer.close()
def body_finish(self):
self.body = self.body_buffer.getvalue()
self.body_buffer.close()
def test_custom_request(): def test_custom_request():
@ -37,17 +34,13 @@ def test_custom_request():
"/post", data=json_dumps(payload), headers=headers "/post", data=json_dumps(payload), headers=headers
) )
assert isinstance(request.body_buffer, BytesIO) assert request.body == b'{"TEST":"OK"}'
assert request.body_buffer.closed assert request.json.get("TEST") == "OK"
assert request.body == b'{"test":"OK"}'
assert request.json.get("test") == "OK"
assert response.text == "OK" assert response.text == "OK"
assert response.status == 200 assert response.status == 200
request, response = app.test_client.get("/get") request, response = app.test_client.get("/get")
assert isinstance(request.body_buffer, BytesIO)
assert request.body_buffer.closed
assert request.body == b"" assert request.body == b""
assert response.text == "OK" assert response.text == "OK"
assert response.status == 200 assert response.status == 200

View File

@ -1,14 +1,26 @@
import asyncio
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from sanic import Sanic from sanic import Sanic
from sanic.exceptions import InvalidUsage, NotFound, ServerError from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.response import text from sanic.response import stream, text
exception_handler_app = Sanic("test_exception_handler") exception_handler_app = Sanic("test_exception_handler")
async def sample_streaming_fn(response):
await response.write("foo,")
await asyncio.sleep(0.001)
await response.write("bar")
class ErrorWithRequestCtx(ServerError):
pass
@exception_handler_app.route("/1") @exception_handler_app.route("/1")
def handler_1(request): def handler_1(request):
raise InvalidUsage("OK") raise InvalidUsage("OK")
@ -47,11 +59,40 @@ def handler_6(request, arg):
return text(foo) return text(foo)
@exception_handler_app.exception(NotFound, ServerError) @exception_handler_app.route("/7")
def handler_7(request):
raise Forbidden("go away!")
@exception_handler_app.route("/8")
def handler_8(request):
raise ErrorWithRequestCtx("OK")
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
def handler_exception_with_ctx(request, exception):
return text(request.ctx.middleware_ran)
@exception_handler_app.exception(ServerError)
def handler_exception(request, exception): def handler_exception(request, exception):
return text("OK") return text("OK")
@exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception):
return stream(
sample_streaming_fn,
content_type="text/csv",
)
@exception_handler_app.middleware
async def some_request_middleware(request):
request.ctx.middleware_ran = "Done."
def test_invalid_usage_exception_handler(): def test_invalid_usage_exception_handler():
request, response = exception_handler_app.test_client.get("/1") request, response = exception_handler_app.test_client.get("/1")
assert response.status == 400 assert response.status == 400
@ -71,7 +112,13 @@ def test_not_found_exception_handler():
def test_text_exception__handler(): def test_text_exception__handler():
request, response = exception_handler_app.test_client.get("/random") request, response = exception_handler_app.test_client.get("/random")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "Done."
def test_async_exception_handler():
request, response = exception_handler_app.test_client.get("/7")
assert response.status == 200
assert response.text == "foo,bar"
def test_html_traceback_output_in_debug_mode(): def test_html_traceback_output_in_debug_mode():
@ -156,3 +203,9 @@ def test_exception_handler_lookup():
assert handler.lookup(CustomError()) == custom_error_handler assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError("Error")) == server_error_handler assert handler.lookup(ServerError("Error")) == server_error_handler
assert handler.lookup(CustomServerError("Error")) == server_error_handler assert handler.lookup(CustomServerError("Error")) == server_error_handler
def test_exception_handler_processed_request_middleware():
request, response = exception_handler_app.test_client.get("/8")
assert response.status == 200
assert response.text == "Done."

View File

@ -1,6 +1,12 @@
from unittest.mock import Mock
import pytest import pytest
from sanic import headers from sanic import Sanic, headers
from sanic.compat import Header
from sanic.exceptions import PayloadTooLarge
from sanic.http import Http
from sanic.request import Request
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -61,3 +67,21 @@ from sanic import headers
) )
def test_parse_headers(input, expected): def test_parse_headers(input, expected):
assert headers.parse_content_header(input) == expected assert headers.parse_content_header(input) == expected
@pytest.mark.asyncio
async def test_header_size_exceeded():
recv_buffer = bytearray()
async def _receive_more():
nonlocal recv_buffer
recv_buffer += b"123"
protocol = Mock()
http = Http(protocol)
http._receive_more = _receive_more
http.request_max_size = 1
http.recv_buffer = recv_buffer
with pytest.raises(PayloadTooLarge):
await http.http1_request_header()

View File

@ -2,11 +2,14 @@ import asyncio
from asyncio import sleep as aio_sleep from asyncio import sleep as aio_sleep
from json import JSONDecodeError from json import JSONDecodeError
from os import environ
import httpcore import httpcore
import httpx import httpx
import pytest
from sanic import Sanic, server from sanic import Sanic, server
from sanic.compat import OS_IS_WINDOWS
from sanic.response import text from sanic.response import text
from sanic.testing import HOST, SanicTestClient from sanic.testing import HOST, SanicTestClient
@ -202,6 +205,10 @@ async def handler3(request):
return text("OK") return text("OK")
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_timeout_reuse(): def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are """If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully both longer than the delay, the client _and_ server will successfully
@ -223,6 +230,10 @@ def test_keep_alive_timeout_reuse():
client.kill_server() client.kill_server()
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_client_timeout(): def test_keep_alive_client_timeout():
"""If the server keep-alive timeout is longer than the client """If the server keep-alive timeout is longer than the client
keep-alive timeout, client will try to create a new connection here.""" keep-alive timeout, client will try to create a new connection here."""
@ -244,6 +255,10 @@ def test_keep_alive_client_timeout():
client.kill_server() client.kill_server()
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_server_timeout(): def test_keep_alive_server_timeout():
"""If the client keep-alive timeout is longer than the server """If the client keep-alive timeout is longer than the server
keep-alive timeout, the client will either a 'Connection reset' error keep-alive timeout, the client will either a 'Connection reset' error

View File

@ -1,16 +1,17 @@
import logging import logging
import os import os
import sys
import uuid import uuid
from importlib import reload from importlib import reload
from io import StringIO from io import StringIO
from unittest.mock import Mock
import pytest import pytest
import sanic import sanic
from sanic import Sanic from sanic import Sanic
from sanic.compat import OS_IS_WINDOWS
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.response import text from sanic.response import text
from sanic.testing import SanicTestClient from sanic.testing import SanicTestClient
@ -111,12 +112,11 @@ def test_log_connection_lost(app, debug, monkeypatch):
@app.route("/conn_lost") @app.route("/conn_lost")
async def conn_lost(request): async def conn_lost(request):
response = text("Ok") response = text("Ok")
response.output = Mock(side_effect=RuntimeError) request.transport.close()
return response return response
with pytest.raises(ValueError): req, res = app.test_client.get("/conn_lost", debug=debug)
# catch ValueError: Exception during request assert res is None
app.test_client.get("/conn_lost", debug=debug)
log = stream.getvalue() log = stream.getvalue()
@ -126,7 +126,8 @@ def test_log_connection_lost(app, debug, monkeypatch):
assert "Connection lost before response written @" not in log assert "Connection lost before response written @" not in log
def test_logger(caplog): @pytest.mark.asyncio
async def test_logger(caplog):
rand_string = str(uuid.uuid4()) rand_string = str(uuid.uuid4())
app = Sanic(name=__name__) app = Sanic(name=__name__)
@ -137,32 +138,16 @@ def test_logger(caplog):
return text("hello") return text("hello")
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
request, response = app.test_client.get("/") _ = await app.asgi_client.get("/")
port = request.server_port record = ("sanic.root", logging.INFO, rand_string)
assert record in caplog.record_tuples
# Note: testing with random port doesn't show the banner because it doesn't
# define host and port. This test supports both modes.
if caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ http://127.0.0.1:{port}",
):
caplog.record_tuples.pop(0)
assert caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"http://127.0.0.1:{port}/",
)
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == (
"sanic.root",
logging.INFO,
"Server Stopped",
)
@pytest.mark.skipif(
OS_IS_WINDOWS and sys.version_info >= (3, 8),
reason="Not testable with current client",
)
def test_logger_static_and_secure(caplog): def test_logger_static_and_secure(caplog):
# Same as test_logger, except for more coverage: # Same as test_logger, except for more coverage:
# - test_client initialised separately for static port # - test_client initialised separately for static port

View File

@ -1,6 +1,7 @@
import logging import logging
from asyncio import CancelledError from asyncio import CancelledError
from itertools import count
from sanic.exceptions import NotFound, SanicException from sanic.exceptions import NotFound, SanicException
from sanic.request import Request from sanic.request import Request
@ -54,7 +55,7 @@ def test_middleware_response(app):
def test_middleware_response_exception(app): def test_middleware_response_exception(app):
result = {"status_code": None} result = {"status_code": "middleware not run"}
@app.middleware("response") @app.middleware("response")
async def process_response(request, response): async def process_response(request, response):
@ -184,3 +185,23 @@ def test_middleware_order(app):
assert response.status == 200 assert response.status == 200
assert order == [1, 2, 3, 4, 5, 6] assert order == [1, 2, 3, 4, 5, 6]
def test_request_middleware_executes_once(app):
i = count()
@app.middleware("request")
async def inc(request):
nonlocal i
next(i)
@app.route("/")
async def handler(request):
await request.app._run_request_middleware(request)
return text("OK")
request, response = app.test_client.get("/")
assert next(i) == 1
request, response = app.test_client.get("/")
assert next(i) == 3

View File

@ -85,7 +85,6 @@ def test_pickle_app_with_bp(app, protocol):
up_p_app = pickle.loads(p_app) up_p_app = pickle.loads(p_app)
assert up_p_app assert up_p_app
request, response = up_p_app.test_client.get("/") request, response = up_p_app.test_client.get("/")
assert up_p_app.is_request_stream is False
assert response.text == "Hello" assert response.text == "Hello"

View File

@ -107,10 +107,8 @@ def test_shorthand_named_routes_post(app):
def test_shorthand_named_routes_put(app): def test_shorthand_named_routes_put(app):
@app.put("/put", name="route_put") @app.put("/put", name="route_put")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/put"].name == "route_put" assert app.router.routes_all["/put"].name == "route_put"
assert app.url_for("route_put") == "/put" assert app.url_for("route_put") == "/put"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
@ -120,10 +118,8 @@ def test_shorthand_named_routes_put(app):
def test_shorthand_named_routes_delete(app): def test_shorthand_named_routes_delete(app):
@app.delete("/delete", name="route_delete") @app.delete("/delete", name="route_delete")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/delete"].name == "route_delete" assert app.router.routes_all["/delete"].name == "route_delete"
assert app.url_for("route_delete") == "/delete" assert app.url_for("route_delete") == "/delete"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
@ -133,10 +129,8 @@ def test_shorthand_named_routes_delete(app):
def test_shorthand_named_routes_patch(app): def test_shorthand_named_routes_patch(app):
@app.patch("/patch", name="route_patch") @app.patch("/patch", name="route_patch")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/patch"].name == "route_patch" assert app.router.routes_all["/patch"].name == "route_patch"
assert app.url_for("route_patch") == "/patch" assert app.url_for("route_patch") == "/patch"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
@ -146,10 +140,8 @@ def test_shorthand_named_routes_patch(app):
def test_shorthand_named_routes_head(app): def test_shorthand_named_routes_head(app):
@app.head("/head", name="route_head") @app.head("/head", name="route_head")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/head"].name == "route_head" assert app.router.routes_all["/head"].name == "route_head"
assert app.url_for("route_head") == "/head" assert app.url_for("route_head") == "/head"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):
@ -159,10 +151,8 @@ def test_shorthand_named_routes_head(app):
def test_shorthand_named_routes_options(app): def test_shorthand_named_routes_options(app):
@app.options("/options", name="route_options") @app.options("/options", name="route_options")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/options"].name == "route_options" assert app.router.routes_all["/options"].name == "route_options"
assert app.url_for("route_options") == "/options" assert app.url_for("route_options") == "/options"
with pytest.raises(URLBuildError): with pytest.raises(URLBuildError):

View File

@ -27,7 +27,7 @@ def test_payload_too_large_at_data_received_default(app):
response = app.test_client.get("/1", gather_request=False) response = app.test_client.get("/1", gather_request=False)
assert response.status == 413 assert response.status == 413
assert "Payload Too Large" in response.text assert "Request header" in response.text
def test_payload_too_large_at_on_header_default(app): def test_payload_too_large_at_on_header_default(app):
@ -40,4 +40,4 @@ def test_payload_too_large_at_on_header_default(app):
data = "a" * 1000 data = "a" * 1000
response = app.test_client.post("/1", gather_request=False, data=data) response = app.test_client.post("/1", gather_request=False, data=data)
assert response.status == 413 assert response.status == 413
assert "Payload Too Large" in response.text assert "Request body" in response.text

View File

@ -1,37 +0,0 @@
import io
from sanic.response import text
data = "abc" * 10_000_000
def test_request_buffer_queue_size(app):
default_buf_qsz = app.config.get("REQUEST_BUFFER_QUEUE_SIZE")
qsz = 1
while qsz == default_buf_qsz:
qsz += 1
app.config.REQUEST_BUFFER_QUEUE_SIZE = qsz
@app.post("/post", stream=True)
async def post(request):
assert request.stream.buffer_size == qsz
print("request.stream.buffer_size =", request.stream.buffer_size)
bio = io.BytesIO()
while True:
bdata = await request.stream.read()
if not bdata:
break
bio.write(bdata)
head = bdata[:3].decode("utf-8")
tail = bdata[3:][-3:].decode("utf-8")
print(head, "...", tail)
bio.seek(0)
return text(bio.read().decode("utf-8"))
request, response = app.test_client.post("/post", data=data)
assert response.status == 200
assert response.text == data

View File

@ -1,11 +1,13 @@
import asyncio import asyncio
from contextlib import closing
from socket import socket
import pytest import pytest
from sanic import Sanic
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed from sanic.response import json, text
from sanic.request import StreamBuffer
from sanic.response import json, stream, text
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
from sanic.views import CompositionView, HTTPMethodView from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator from sanic.views import stream as stream_decorator
@ -15,28 +17,22 @@ data = "abc" * 1_000_000
def test_request_stream_method_view(app): def test_request_stream_method_view(app):
"""for self.is_request_stream = True"""
class SimpleView(HTTPMethodView): class SimpleView(HTTPMethodView):
def get(self, request): def get(self, request):
assert request.stream is None
return text("OK") return text("OK")
@stream_decorator @stream_decorator
async def post(self, request): async def post(self, request):
assert isinstance(request.stream, StreamBuffer) result = b""
result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
if body is None: if body is None:
break break
result += body.decode("utf-8") result += body
return text(result) return text(result.decode())
app.add_route(SimpleView.as_view(), "/method_view") app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
request, response = app.test_client.get("/method_view") request, response = app.test_client.get("/method_view")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
@ -50,14 +46,16 @@ def test_request_stream_method_view(app):
"headers, expect_raise_exception", "headers, expect_raise_exception",
[ [
({"EXPECT": "100-continue"}, False), ({"EXPECT": "100-continue"}, False),
({"EXPECT": "100-continue-extra"}, True), # The below test SHOULD work, and it does produce a 417
# However, httpx now intercepts this and raises an exception,
# so we will need a new method for testing this
# ({"EXPECT": "100-continue-extra"}, True),
], ],
) )
def test_request_stream_100_continue(app, headers, expect_raise_exception): def test_request_stream_100_continue(app, headers, expect_raise_exception):
class SimpleView(HTTPMethodView): class SimpleView(HTTPMethodView):
@stream_decorator @stream_decorator
async def post(self, request): async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -68,55 +66,42 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception):
app.add_route(SimpleView.as_view(), "/method_view") app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
if not expect_raise_exception: if not expect_raise_exception:
request, response = app.test_client.post( request, response = app.test_client.post(
"/method_view", data=data, headers={"EXPECT": "100-continue"} "/method_view", data=data, headers=headers
) )
assert response.status == 200 assert response.status == 200
assert response.text == data assert response.text == data
else: else:
with pytest.raises(ValueError) as e: request, response = app.test_client.post(
app.test_client.post( "/method_view", data=data, headers=headers
"/method_view",
data=data,
headers={"EXPECT": "100-continue-extra"},
) )
assert "Unknown Expect: 100-continue-extra" in str(e) assert response.status == 417
def test_request_stream_app(app): def test_request_stream_app(app):
"""for self.is_request_stream = True and decorators"""
@app.get("/get") @app.get("/get")
async def get(request): async def get(request):
assert request.stream is None
return text("GET") return text("GET")
@app.head("/head") @app.head("/head")
async def head(request): async def head(request):
assert request.stream is None
return text("HEAD") return text("HEAD")
@app.delete("/delete") @app.delete("/delete")
async def delete(request): async def delete(request):
assert request.stream is None
return text("DELETE") return text("DELETE")
@app.options("/options") @app.options("/options")
async def options(request): async def options(request):
assert request.stream is None
return text("OPTIONS") return text("OPTIONS")
@app.post("/_post/<id>") @app.post("/_post/<id>")
async def _post(request, id): async def _post(request, id):
assert request.stream is None
return text("_POST") return text("_POST")
@app.post("/post/<id>", stream=True) @app.post("/post/<id>", stream=True)
async def post(request, id): async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -127,12 +112,10 @@ def test_request_stream_app(app):
@app.put("/_put") @app.put("/_put")
async def _put(request): async def _put(request):
assert request.stream is None
return text("_PUT") return text("_PUT")
@app.put("/put", stream=True) @app.put("/put", stream=True)
async def put(request): async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -143,12 +126,10 @@ def test_request_stream_app(app):
@app.patch("/_patch") @app.patch("/_patch")
async def _patch(request): async def _patch(request):
assert request.stream is None
return text("_PATCH") return text("_PATCH")
@app.patch("/patch", stream=True) @app.patch("/patch", stream=True)
async def patch(request): async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -157,8 +138,6 @@ def test_request_stream_app(app):
result += body.decode("utf-8") result += body.decode("utf-8")
return text(result) return text(result)
assert app.is_request_stream is True
request, response = app.test_client.get("/get") request, response = app.test_client.get("/get")
assert response.status == 200 assert response.status == 200
assert response.text == "GET" assert response.text == "GET"
@ -202,36 +181,28 @@ def test_request_stream_app(app):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_stream_app_asgi(app): async def test_request_stream_app_asgi(app):
"""for self.is_request_stream = True and decorators"""
@app.get("/get") @app.get("/get")
async def get(request): async def get(request):
assert request.stream is None
return text("GET") return text("GET")
@app.head("/head") @app.head("/head")
async def head(request): async def head(request):
assert request.stream is None
return text("HEAD") return text("HEAD")
@app.delete("/delete") @app.delete("/delete")
async def delete(request): async def delete(request):
assert request.stream is None
return text("DELETE") return text("DELETE")
@app.options("/options") @app.options("/options")
async def options(request): async def options(request):
assert request.stream is None
return text("OPTIONS") return text("OPTIONS")
@app.post("/_post/<id>") @app.post("/_post/<id>")
async def _post(request, id): async def _post(request, id):
assert request.stream is None
return text("_POST") return text("_POST")
@app.post("/post/<id>", stream=True) @app.post("/post/<id>", stream=True)
async def post(request, id): async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -242,12 +213,10 @@ async def test_request_stream_app_asgi(app):
@app.put("/_put") @app.put("/_put")
async def _put(request): async def _put(request):
assert request.stream is None
return text("_PUT") return text("_PUT")
@app.put("/put", stream=True) @app.put("/put", stream=True)
async def put(request): async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -258,12 +227,10 @@ async def test_request_stream_app_asgi(app):
@app.patch("/_patch") @app.patch("/_patch")
async def _patch(request): async def _patch(request):
assert request.stream is None
return text("_PATCH") return text("_PATCH")
@app.patch("/patch", stream=True) @app.patch("/patch", stream=True)
async def patch(request): async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -272,8 +239,6 @@ async def test_request_stream_app_asgi(app):
result += body.decode("utf-8") result += body.decode("utf-8")
return text(result) return text(result)
assert app.is_request_stream is True
request, response = await app.asgi_client.get("/get") request, response = await app.asgi_client.get("/get")
assert response.status == 200 assert response.status == 200
assert response.text == "GET" assert response.text == "GET"
@ -320,14 +285,13 @@ def test_request_stream_handle_exception(app):
@app.post("/post/<id>", stream=True) @app.post("/post/<id>", stream=True)
async def post(request, id): async def post(request, id):
assert isinstance(request.stream, StreamBuffer) result = b""
result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
if body is None: if body is None:
break break
result += body.decode("utf-8") result += body
return text(result) return text(result.decode())
# 404 # 404
request, response = app.test_client.post("/in_valid_post", data=data) request, response = app.test_client.post("/in_valid_post", data=data)
@ -340,54 +304,31 @@ def test_request_stream_handle_exception(app):
assert "Method GET not allowed for URL /post/random_id" in response.text assert "Method GET not allowed for URL /post/random_id" in response.text
@pytest.mark.asyncio
async def test_request_stream_unread(app):
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
err = None
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
try:
protocol.request = None
protocol._body_chunks.append("this is a test")
await protocol.stream_append()
except AttributeError as e:
err = e
assert err is None and not protocol._body_chunks
def test_request_stream_blueprint(app): def test_request_stream_blueprint(app):
"""for self.is_request_stream = True"""
bp = Blueprint("test_blueprint_request_stream_blueprint") bp = Blueprint("test_blueprint_request_stream_blueprint")
@app.get("/get") @app.get("/get")
async def get(request): async def get(request):
assert request.stream is None
return text("GET") return text("GET")
@bp.head("/head") @bp.head("/head")
async def head(request): async def head(request):
assert request.stream is None
return text("HEAD") return text("HEAD")
@bp.delete("/delete") @bp.delete("/delete")
async def delete(request): async def delete(request):
assert request.stream is None
return text("DELETE") return text("DELETE")
@bp.options("/options") @bp.options("/options")
async def options(request): async def options(request):
assert request.stream is None
return text("OPTIONS") return text("OPTIONS")
@bp.post("/_post/<id>") @bp.post("/_post/<id>")
async def _post(request, id): async def _post(request, id):
assert request.stream is None
return text("_POST") return text("_POST")
@bp.post("/post/<id>", stream=True) @bp.post("/post/<id>", stream=True)
async def post(request, id): async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -398,12 +339,10 @@ def test_request_stream_blueprint(app):
@bp.put("/_put") @bp.put("/_put")
async def _put(request): async def _put(request):
assert request.stream is None
return text("_PUT") return text("_PUT")
@bp.put("/put", stream=True) @bp.put("/put", stream=True)
async def put(request): async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -414,12 +353,10 @@ def test_request_stream_blueprint(app):
@bp.patch("/_patch") @bp.patch("/_patch")
async def _patch(request): async def _patch(request):
assert request.stream is None
return text("_PATCH") return text("_PATCH")
@bp.patch("/patch", stream=True) @bp.patch("/patch", stream=True)
async def patch(request): async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -429,7 +366,6 @@ def test_request_stream_blueprint(app):
return text(result) return text(result)
async def post_add_route(request): async def post_add_route(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -443,8 +379,6 @@ def test_request_stream_blueprint(app):
) )
app.blueprint(bp) app.blueprint(bp)
assert app.is_request_stream is True
request, response = app.test_client.get("/get") request, response = app.test_client.get("/get")
assert response.status == 200 assert response.status == 200
assert response.text == "GET" assert response.text == "GET"
@ -491,14 +425,10 @@ def test_request_stream_blueprint(app):
def test_request_stream_composition_view(app): def test_request_stream_composition_view(app):
"""for self.is_request_stream = True"""
def get_handler(request): def get_handler(request):
assert request.stream is None
return text("OK") return text("OK")
async def post_handler(request): async def post_handler(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -512,8 +442,6 @@ def test_request_stream_composition_view(app):
view.add(["POST"], post_handler, stream=True) view.add(["POST"], post_handler, stream=True)
app.add_route(view, "/composition_view") app.add_route(view, "/composition_view")
assert app.is_request_stream is True
request, response = app.test_client.get("/composition_view") request, response = app.test_client.get("/composition_view")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
@ -529,12 +457,10 @@ def test_request_stream(app):
class SimpleView(HTTPMethodView): class SimpleView(HTTPMethodView):
def get(self, request): def get(self, request):
assert request.stream is None
return text("OK") return text("OK")
@stream_decorator @stream_decorator
async def post(self, request): async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -545,7 +471,6 @@ def test_request_stream(app):
@app.post("/stream", stream=True) @app.post("/stream", stream=True)
async def handler(request): async def handler(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -556,12 +481,10 @@ def test_request_stream(app):
@app.get("/get") @app.get("/get")
async def get(request): async def get(request):
assert request.stream is None
return text("OK") return text("OK")
@bp.post("/bp_stream", stream=True) @bp.post("/bp_stream", stream=True)
async def bp_stream(request): async def bp_stream(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -572,15 +495,12 @@ def test_request_stream(app):
@bp.get("/bp_get") @bp.get("/bp_get")
async def bp_get(request): async def bp_get(request):
assert request.stream is None
return text("OK") return text("OK")
def get_handler(request): def get_handler(request):
assert request.stream is None
return text("OK") return text("OK")
async def post_handler(request): async def post_handler(request):
assert isinstance(request.stream, StreamBuffer)
result = "" result = ""
while True: while True:
body = await request.stream.read() body = await request.stream.read()
@ -599,8 +519,6 @@ def test_request_stream(app):
app.add_route(view, "/composition_view") app.add_route(view, "/composition_view")
assert app.is_request_stream is True
request, response = app.test_client.get("/method_view") request, response = app.test_client.get("/method_view")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
@ -636,14 +554,14 @@ def test_request_stream(app):
def test_streaming_new_api(app): def test_streaming_new_api(app):
@app.post("/non-stream") @app.post("/non-stream")
async def handler(request): async def handler1(request):
assert request.body == b"x" assert request.body == b"x"
await request.receive_body() # This should do nothing await request.receive_body() # This should do nothing
assert request.body == b"x" assert request.body == b"x"
return text("OK") return text("OK")
@app.post("/1", stream=True) @app.post("/1", stream=True)
async def handler(request): async def handler2(request):
assert request.stream assert request.stream
assert not request.body assert not request.body
await request.receive_body() await request.receive_body()
@ -671,5 +589,85 @@ def test_streaming_new_api(app):
assert response.status == 200 assert response.status == 200
res = response.json res = response.json
assert isinstance(res, list) assert isinstance(res, list)
assert len(res) > 1
assert "".join(res) == data assert "".join(res) == data
def test_streaming_echo():
"""2-way streaming chat between server and client."""
app = Sanic(name=__name__)
@app.post("/echo", stream=True)
async def handler(request):
res = await request.respond(content_type="text/plain; charset=utf-8")
# Send headers
await res.send(end_stream=False)
# Echo back data (case swapped)
async for data in request.stream:
await res.send(data.swapcase())
# Add EOF marker after successful operation
await res.send(b"-", end_stream=True)
@app.listener("after_server_start")
async def client_task(app, loop):
try:
reader, writer = await asyncio.open_connection(*addr)
await client(app, reader, writer)
finally:
writer.close()
app.stop()
async def client(app, reader, writer):
# Unfortunately httpx does not support 2-way streaming, so do it by hand.
host = f"host: {addr[0]}:{addr[1]}\r\n".encode()
writer.write(
b"POST /echo HTTP/1.1\r\n" + host + b"content-length: 2\r\n"
b"content-type: text/plain; charset=utf-8\r\n"
b"\r\n"
)
# Read response
res = b""
while not b"\r\n\r\n" in res:
res += await reader.read(4096)
assert res.startswith(b"HTTP/1.1 200 OK\r\n")
assert res.endswith(b"\r\n\r\n")
buffer = b""
async def read_chunk():
nonlocal buffer
while not b"\r\n" in buffer:
data = await reader.read(4096)
assert data
buffer += data
size, buffer = buffer.split(b"\r\n", 1)
size = int(size, 16)
if size == 0:
return None
while len(buffer) < size + 2:
data = await reader.read(4096)
assert data
buffer += data
print(res)
assert buffer[size : size + 2] == b"\r\n"
ret, buffer = buffer[:size], buffer[size + 2 :]
return ret
# Chat with server
writer.write(b"a")
res = await read_chunk()
assert res == b"A"
writer.write(b"b")
res = await read_chunk()
assert res == b"B"
res = await read_chunk()
assert res == b"-"
res = await read_chunk()
assert res == None
# Use random port for tests
with closing(socket()) as sock:
sock.bind(("127.0.0.1", 0))
addr = sock.getsockname()
app.run(sock=sock, access_log=False)

View File

@ -59,7 +59,7 @@ def test_ip(app):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ip_asgi(app): async def test_url_asgi(app):
@app.route("/") @app.route("/")
def handler(request): def handler(request):
return text(f"{request.url}") return text(f"{request.url}")
@ -2119,3 +2119,37 @@ def test_url_for_without_server_name(app):
response.json["url"] response.json["url"]
== f"http://127.0.0.1:{request.server_port}/url-for" == f"http://127.0.0.1:{request.server_port}/url-for"
) )
def test_safe_method_with_body_ignored(app):
@app.get("/")
async def handler(request):
return text("OK")
payload = {"test": "OK"}
headers = {"content-type": "application/json"}
request, response = app.test_client.request(
"/", http_method="get", data=json_dumps(payload), headers=headers
)
assert request.body == b""
assert request.json == None
assert response.text == "OK"
def test_safe_method_with_body(app):
@app.get("/", ignore_body=False)
async def handler(request):
return text("OK")
payload = {"test": "OK"}
headers = {"content-type": "application/json"}
data = json_dumps(payload)
request, response = app.test_client.request(
"/", http_method="get", data=data, headers=headers
)
assert request.body == data.encode("utf-8")
assert request.json.get("test") == "OK"
assert response.text == "OK"

View File

@ -85,7 +85,6 @@ def test_response_header(app):
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert dict(response.headers) == { assert dict(response.headers) == {
"connection": "keep-alive", "connection": "keep-alive",
"keep-alive": str(app.config.KEEP_ALIVE_TIMEOUT),
"content-length": "11", "content-length": "11",
"content-type": "application/json", "content-type": "application/json",
} }
@ -201,7 +200,6 @@ def streaming_app(app):
async def test(request): async def test(request):
return stream( return stream(
sample_streaming_fn, sample_streaming_fn,
headers={"Content-Length": "7"},
content_type="text/csv", content_type="text/csv",
) )
@ -243,7 +241,12 @@ async def test_chunked_streaming_returns_correct_content_asgi(streaming_app):
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
with pytest.warns(UserWarning) as record:
request, response = non_chunked_streaming_app.test_client.get("/") request, response = non_chunked_streaming_app.test_client.get("/")
assert len(record) == 1
assert "removed in v21.6" in record[0].message.args[0]
assert "Transfer-Encoding" not in response.headers assert "Transfer-Encoding" not in response.headers
assert response.headers["Content-Type"] == "text/csv" assert response.headers["Content-Type"] == "text/csv"
assert response.headers["Content-Length"] == "7" assert response.headers["Content-Length"] == "7"
@ -266,115 +269,6 @@ def test_non_chunked_streaming_returns_correct_content(
assert response.text == "foo,bar" assert response.text == "foo,bar"
@pytest.mark.parametrize("status", [200, 201, 400, 401])
def test_stream_response_status_returns_correct_headers(status):
response = StreamingHTTPResponse(sample_streaming_fn, status=status)
headers = response.get_headers()
assert b"HTTP/1.1 %s" % str(status).encode() in headers
@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
def test_stream_response_keep_alive_returns_correct_headers(
keep_alive_timeout,
):
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(
keep_alive=True, keep_alive_timeout=keep_alive_timeout
)
assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers
def test_stream_response_includes_chunked_header_http11():
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(version="1.1")
assert b"Transfer-Encoding: chunked\r\n" in headers
def test_stream_response_does_not_include_chunked_header_http10():
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(version="1.0")
assert b"Transfer-Encoding: chunked\r\n" not in headers
def test_stream_response_does_not_include_chunked_header_if_disabled():
response = StreamingHTTPResponse(sample_streaming_fn, chunked=False)
headers = response.get_headers(version="1.1")
assert b"Transfer-Encoding: chunked\r\n" not in headers
def test_stream_response_writes_correct_content_to_transport_when_chunked(
streaming_app,
):
response = StreamingHTTPResponse(sample_streaming_fn)
response.protocol = MagicMock(HttpProtocol)
response.protocol.transport = MagicMock(asyncio.Transport)
async def mock_drain():
pass
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data
response.protocol.drain = mock_drain
@streaming_app.listener("after_server_start")
async def run_stream(app, loop):
await response.stream()
assert response.protocol.transport.write.call_args_list[1][0][0] == (
b"4\r\nfoo,\r\n"
)
assert response.protocol.transport.write.call_args_list[2][0][0] == (
b"3\r\nbar\r\n"
)
assert response.protocol.transport.write.call_args_list[3][0][0] == (
b"0\r\n\r\n"
)
assert len(response.protocol.transport.write.call_args_list) == 4
app.stop()
streaming_app.run(host=HOST, port=PORT)
def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
streaming_app,
):
response = StreamingHTTPResponse(sample_streaming_fn)
response.protocol = MagicMock(HttpProtocol)
response.protocol.transport = MagicMock(asyncio.Transport)
async def mock_drain():
pass
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data
response.protocol.drain = mock_drain
@streaming_app.listener("after_server_start")
async def run_stream(app, loop):
await response.stream(version="1.0")
assert response.protocol.transport.write.call_args_list[1][0][0] == (
b"foo,"
)
assert response.protocol.transport.write.call_args_list[2][0][0] == (
b"bar"
)
assert len(response.protocol.transport.write.call_args_list) == 3
app.stop()
streaming_app.run(host=HOST, port=PORT)
def test_stream_response_with_cookies(app): def test_stream_response_with_cookies(app):
@app.route("/") @app.route("/")
async def test(request): async def test(request):

View File

@ -67,16 +67,12 @@ def test_shorthand_routes_multiple(app):
def test_route_strict_slash(app): def test_route_strict_slash(app):
@app.get("/get", strict_slashes=True) @app.get("/get", strict_slashes=True)
def handler1(request): def handler1(request):
assert request.stream is None
return text("OK") return text("OK")
@app.post("/post/", strict_slashes=True) @app.post("/post/", strict_slashes=True)
def handler2(request): def handler2(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.get("/get") request, response = app.test_client.get("/get")
assert response.text == "OK" assert response.text == "OK"
@ -216,11 +212,8 @@ def test_shorthand_routes_post(app):
def test_shorthand_routes_put(app): def test_shorthand_routes_put(app):
@app.put("/put") @app.put("/put")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.put("/put") request, response = app.test_client.put("/put")
assert response.text == "OK" assert response.text == "OK"
@ -231,11 +224,8 @@ def test_shorthand_routes_put(app):
def test_shorthand_routes_delete(app): def test_shorthand_routes_delete(app):
@app.delete("/delete") @app.delete("/delete")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.delete("/delete") request, response = app.test_client.delete("/delete")
assert response.text == "OK" assert response.text == "OK"
@ -246,11 +236,8 @@ def test_shorthand_routes_delete(app):
def test_shorthand_routes_patch(app): def test_shorthand_routes_patch(app):
@app.patch("/patch") @app.patch("/patch")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.patch("/patch") request, response = app.test_client.patch("/patch")
assert response.text == "OK" assert response.text == "OK"
@ -261,11 +248,8 @@ def test_shorthand_routes_patch(app):
def test_shorthand_routes_head(app): def test_shorthand_routes_head(app):
@app.head("/head") @app.head("/head")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.head("/head") request, response = app.test_client.head("/head")
assert response.status == 200 assert response.status == 200
@ -276,11 +260,8 @@ def test_shorthand_routes_head(app):
def test_shorthand_routes_options(app): def test_shorthand_routes_options(app):
@app.options("/options") @app.options("/options")
def handler(request): def handler(request):
assert request.stream is None
return text("OK") return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.options("/options") request, response = app.test_client.options("/options")
assert response.status == 200 assert response.status == 200

View File

@ -0,0 +1,74 @@
import asyncio
from time import monotonic as current_time
from unittest.mock import Mock
import pytest
from sanic import Sanic
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Stage
from sanic.server import HttpProtocol
@pytest.fixture
def app():
return Sanic("test")
@pytest.fixture
def mock_transport():
return Mock()
@pytest.fixture
def protocol(app, mock_transport):
loop = asyncio.new_event_loop()
protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport)
protocol._setup_connection()
protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock()
return protocol
def test_setup(protocol: HttpProtocol):
assert protocol._task is not None
assert protocol._http is not None
assert protocol._time is not None
def test_check_timeouts_no_timeout(protocol: HttpProtocol):
protocol.keep_alive_timeout = 1
protocol.loop.call_later = Mock()
protocol.check_timeouts()
protocol._task.cancel.assert_not_called()
assert protocol._http.stage is Stage.IDLE
assert protocol._http.exception is None
protocol.loop.call_later.assert_called_with(
protocol.keep_alive_timeout / 2, protocol.check_timeouts
)
def test_check_timeouts_keep_alive_timeout(protocol: HttpProtocol):
protocol._http.stage = Stage.IDLE
protocol._time = 0
protocol.check_timeouts()
protocol._task.cancel.assert_called_once()
assert protocol._http.exception is None
def test_check_timeouts_request_timeout(protocol: HttpProtocol):
protocol._http.stage = Stage.REQUEST
protocol._time = 0
protocol.check_timeouts()
protocol._task.cancel.assert_called_once()
assert isinstance(protocol._http.exception, RequestTimeout)
def test_check_timeouts_response_timeout(protocol: HttpProtocol):
protocol._http.stage = Stage.RESPONSE
protocol._time = 0
protocol.check_timeouts()
protocol._task.cancel.assert_called_once()
assert isinstance(protocol._http.exception, ServiceUnavailable)

View File

@ -12,7 +12,6 @@ from sanic.views import CompositionView, HTTPMethodView
def test_methods(app, method): def test_methods(app, method):
class DummyView(HTTPMethodView): class DummyView(HTTPMethodView):
async def get(self, request): async def get(self, request):
assert request.stream is None
return text("", headers={"method": "GET"}) return text("", headers={"method": "GET"})
def post(self, request): def post(self, request):
@ -34,7 +33,6 @@ def test_methods(app, method):
return text("", headers={"method": "DELETE"}) return text("", headers={"method": "DELETE"})
app.add_route(DummyView.as_view(), "/") app.add_route(DummyView.as_view(), "/")
assert app.is_request_stream is False
request, response = getattr(app.test_client, method.lower())("/") request, response = getattr(app.test_client, method.lower())("/")
assert response.headers["method"] == method assert response.headers["method"] == method
@ -69,7 +67,6 @@ def test_with_bp(app):
class DummyView(HTTPMethodView): class DummyView(HTTPMethodView):
def get(self, request): def get(self, request):
assert request.stream is None
return text("I am get method") return text("I am get method")
bp.add_route(DummyView.as_view(), "/") bp.add_route(DummyView.as_view(), "/")
@ -77,7 +74,6 @@ def test_with_bp(app):
app.blueprint(bp) app.blueprint(bp)
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert app.is_request_stream is False
assert response.text == "I am get method" assert response.text == "I am get method"
@ -210,14 +206,12 @@ def test_composition_view_runs_methods_as_expected(app, method):
view = CompositionView() view = CompositionView()
def first(request): def first(request):
assert request.stream is None
return text("first method") return text("first method")
view.add(["GET", "POST", "PUT"], first) view.add(["GET", "POST", "PUT"], first)
view.add(["DELETE", "PATCH"], lambda x: text("second method")) view.add(["DELETE", "PATCH"], lambda x: text("second method"))
app.add_route(view, "/") app.add_route(view, "/")
assert app.is_request_stream is False
if method in ["GET", "POST", "PUT"]: if method in ["GET", "POST", "PUT"]:
request, response = getattr(app.test_client, method.lower())("/") request, response = getattr(app.test_client, method.lower())("/")