Streaming Server (#1876)

* Streaming request by async for.

* Make all requests streaming and preload body for non-streaming handlers.

* Cleanup of code and avoid mixing streaming responses.

* Async http protocol loop.

* Change of test: don't require early bad request error but only after CRLF-CRLF.

* Add back streaming requests.

* Rewritten request body parser.

* Misc. cleanup, down to 4 failing tests.

* All tests OK.

* Entirely remove request body queue.

* Let black f*ckup the layout

* Better testing error messages on protocol errors.

* Remove StreamBuffer tests because the type is about to be removed.

* Remove tests using the deprecated get_headers function that can no longer be supported. Chunked mode is now autodetected, so do not put content-length header if chunked mode is preferred.

* Major refactoring of HTTP protocol handling (new module http.py added), all requests made streaming. A few compatibility issues and a lot of cleanup to be done remain, 16 tests failing.

* Terminate check_timeouts once connection_task finishes.

* Code cleanup, 14 tests failing.

* Much cleanup, 12 failing...

* Even more cleanup and error checking, 8 failing tests.

* Remove keep-alive header from responses. First of all, it should say timeout=<value> which wasn't the case with existing implementation, and secondly none of the other web servers I tried include this header.

* Everything but CustomServer OK.

* Linter

* Disable custom protocol test

* Remove unnecessary variables, optimise performance.

* A test was missing that body_init/body_push/body_finish are never called. Rewritten using receive_body and case switching to make it fail if bypassed.

* Minor fixes.

* Remove unused code.

* Py 3.8 check for deprecated loop argument.

* Fix a middleware cancellation handling test with py38.

* Linter 'n fixes

* Typing

* Stricter handling of request header size

* More specific error messages on Payload Too Large.

* Init http.response = None

* Messages further tuned.

* Always try to consume request body, plus minor cleanup.

* Add a missing check in case of close_if_idle on a dead connection.

* Avoid error messages on PayloadTooLarge.

* Add test for new API.

* json takes str, not bytes

* Default to no maximum request size for streaming handlers.

* Fix chunked mode crash.

* Header values should be strictly ASCII but both UTF-8 and Latin-1 exist. Use UTF-8B to
cope with all.

* Refactoring and cleanup.

* Unify response header processing of ASGI and asyncio modes.

* Avoid special handling of StreamingHTTPResponse.

* 35 % speedup in HTTP/1.1 response formatting (not so much overall effect).

* Duplicate set-cookie headers were being produced.

* Cleanup processed_headers some more.

* Linting

* Import ordering

* Response middleware ran by async request.respond().

* Need to check if transport is closing to avoid getting stuck in sending loops after peer has disconnected.

* Middleware and error handling refactoring.

* Linter

* Fix tracking of HTTP stage when writing to transport fails.

* Add clarifying comment

* Add a check for request body functions and a test for NotImplementedError.

* Linter and typing

* These must be tuples + hack mypy warnings away.

* New streaming test and minor fixes.

* Constant receive buffer size.

* 256 KiB send and receive buffers.

* Revert "256 KiB send and receive buffers."

This reverts commit abc1e3edb2.

* app.handle_exception already sends the response.

* Improved handling of errors during request.

* An odd hack to avoid an httpx limitation that causes test failures.

* Limit request header size to 8 KiB at most.

* Remove unnecessary use of format string.

* Cleanup tests

* Remove artifact

* Fix type checking

* Mark test for skipping

* Cleanup some edge cases

* Add ignore_body flag to safe methods

* Add unit tests for timeout logic

* Add unit tests for timeout logic

* Fix Mock usage in timeout test

* Change logging test to only logger in handler

* Windows py3.8 logging issue with current testing client

* Add test_header_size_exceeded

* Resolve merge conflicts

* Add request middleware to hard exception handling

* Add request middleware to hard exception handling

* Request middleware on exception handlers

* Linting

* Cleanup deprecations

Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
L. Kärkkäinen
2021-01-11 00:45:36 +02:00
committed by GitHub
parent 574a9c27a6
commit 7028eae083
35 changed files with 1372 additions and 1348 deletions

View File

@@ -118,7 +118,7 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs):
return None, [], {}, "", "", None
return None, [], {}, "", "", None, False
# Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn)

View File

@@ -8,7 +8,7 @@ def test_bad_request_response(app):
async def _request(sanic, loop):
connect = asyncio.open_connection("127.0.0.1", 42101)
reader, writer = await connect
writer.write(b"not http")
writer.write(b"not http\r\n\r\n")
while True:
line = await reader.readline()
if not line:

View File

@@ -71,7 +71,6 @@ def test_bp(app):
app.blueprint(bp)
request, response = app.test_client.get("/")
assert app.is_request_stream is False
assert response.text == "Hello"
@@ -505,48 +504,38 @@ def test_bp_shorthand(app):
@blueprint.get("/get")
def handler(request):
assert request.stream is None
return text("OK")
@blueprint.put("/put")
def put_handler(request):
assert request.stream is None
return text("OK")
@blueprint.post("/post")
def post_handler(request):
assert request.stream is None
return text("OK")
@blueprint.head("/head")
def head_handler(request):
assert request.stream is None
return text("OK")
@blueprint.options("/options")
def options_handler(request):
assert request.stream is None
return text("OK")
@blueprint.patch("/patch")
def patch_handler(request):
assert request.stream is None
return text("OK")
@blueprint.delete("/delete")
def delete_handler(request):
assert request.stream is None
return text("OK")
@blueprint.websocket("/ws/", strict_slashes=True)
async def websocket_handler(request, ws):
assert request.stream is None
ev.set()
app.blueprint(blueprint)
assert app.is_request_stream is False
request, response = app.test_client.get("/get")
assert response.text == "OK"

View File

@@ -6,17 +6,14 @@ from sanic.response import json_dumps, text
class CustomRequest(Request):
__slots__ = ("body_buffer",)
"""Alternative implementation for loading body (non-streaming handlers)"""
def body_init(self):
self.body_buffer = BytesIO()
def body_push(self, data):
self.body_buffer.write(data)
def body_finish(self):
self.body = self.body_buffer.getvalue()
self.body_buffer.close()
async def receive_body(self):
buffer = BytesIO()
async for data in self.stream:
buffer.write(data)
self.body = buffer.getvalue().upper()
buffer.close()
def test_custom_request():
@@ -37,17 +34,13 @@ def test_custom_request():
"/post", data=json_dumps(payload), headers=headers
)
assert isinstance(request.body_buffer, BytesIO)
assert request.body_buffer.closed
assert request.body == b'{"test":"OK"}'
assert request.json.get("test") == "OK"
assert request.body == b'{"TEST":"OK"}'
assert request.json.get("TEST") == "OK"
assert response.text == "OK"
assert response.status == 200
request, response = app.test_client.get("/get")
assert isinstance(request.body_buffer, BytesIO)
assert request.body_buffer.closed
assert request.body == b""
assert response.text == "OK"
assert response.status == 200

View File

@@ -1,14 +1,26 @@
import asyncio
from bs4 import BeautifulSoup
from sanic import Sanic
from sanic.exceptions import InvalidUsage, NotFound, ServerError
from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError
from sanic.handlers import ErrorHandler
from sanic.response import text
from sanic.response import stream, text
exception_handler_app = Sanic("test_exception_handler")
async def sample_streaming_fn(response):
await response.write("foo,")
await asyncio.sleep(0.001)
await response.write("bar")
class ErrorWithRequestCtx(ServerError):
pass
@exception_handler_app.route("/1")
def handler_1(request):
raise InvalidUsage("OK")
@@ -47,11 +59,40 @@ def handler_6(request, arg):
return text(foo)
@exception_handler_app.exception(NotFound, ServerError)
@exception_handler_app.route("/7")
def handler_7(request):
raise Forbidden("go away!")
@exception_handler_app.route("/8")
def handler_8(request):
raise ErrorWithRequestCtx("OK")
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
def handler_exception_with_ctx(request, exception):
return text(request.ctx.middleware_ran)
@exception_handler_app.exception(ServerError)
def handler_exception(request, exception):
return text("OK")
@exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception):
return stream(
sample_streaming_fn,
content_type="text/csv",
)
@exception_handler_app.middleware
async def some_request_middleware(request):
request.ctx.middleware_ran = "Done."
def test_invalid_usage_exception_handler():
request, response = exception_handler_app.test_client.get("/1")
assert response.status == 400
@@ -71,7 +112,13 @@ def test_not_found_exception_handler():
def test_text_exception__handler():
request, response = exception_handler_app.test_client.get("/random")
assert response.status == 200
assert response.text == "OK"
assert response.text == "Done."
def test_async_exception_handler():
request, response = exception_handler_app.test_client.get("/7")
assert response.status == 200
assert response.text == "foo,bar"
def test_html_traceback_output_in_debug_mode():
@@ -156,3 +203,9 @@ def test_exception_handler_lookup():
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError("Error")) == server_error_handler
assert handler.lookup(CustomServerError("Error")) == server_error_handler
def test_exception_handler_processed_request_middleware():
request, response = exception_handler_app.test_client.get("/8")
assert response.status == 200
assert response.text == "Done."

View File

@@ -1,6 +1,12 @@
from unittest.mock import Mock
import pytest
from sanic import headers
from sanic import Sanic, headers
from sanic.compat import Header
from sanic.exceptions import PayloadTooLarge
from sanic.http import Http
from sanic.request import Request
@pytest.mark.parametrize(
@@ -61,3 +67,21 @@ from sanic import headers
)
def test_parse_headers(input, expected):
assert headers.parse_content_header(input) == expected
@pytest.mark.asyncio
async def test_header_size_exceeded():
recv_buffer = bytearray()
async def _receive_more():
nonlocal recv_buffer
recv_buffer += b"123"
protocol = Mock()
http = Http(protocol)
http._receive_more = _receive_more
http.request_max_size = 1
http.recv_buffer = recv_buffer
with pytest.raises(PayloadTooLarge):
await http.http1_request_header()

View File

@@ -2,11 +2,14 @@ import asyncio
from asyncio import sleep as aio_sleep
from json import JSONDecodeError
from os import environ
import httpcore
import httpx
import pytest
from sanic import Sanic, server
from sanic.compat import OS_IS_WINDOWS
from sanic.response import text
from sanic.testing import HOST, SanicTestClient
@@ -202,6 +205,10 @@ async def handler3(request):
return text("OK")
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully
@@ -223,6 +230,10 @@ def test_keep_alive_timeout_reuse():
client.kill_server()
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_client_timeout():
"""If the server keep-alive timeout is longer than the client
keep-alive timeout, client will try to create a new connection here."""
@@ -244,6 +255,10 @@ def test_keep_alive_client_timeout():
client.kill_server()
@pytest.mark.skipif(
bool(environ.get("SANIC_NO_UVLOOP")) or OS_IS_WINDOWS,
reason="Not testable with current client",
)
def test_keep_alive_server_timeout():
"""If the client keep-alive timeout is longer than the server
keep-alive timeout, the client will either a 'Connection reset' error

View File

@@ -1,16 +1,17 @@
import logging
import os
import sys
import uuid
from importlib import reload
from io import StringIO
from unittest.mock import Mock
import pytest
import sanic
from sanic import Sanic
from sanic.compat import OS_IS_WINDOWS
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.response import text
from sanic.testing import SanicTestClient
@@ -111,12 +112,11 @@ def test_log_connection_lost(app, debug, monkeypatch):
@app.route("/conn_lost")
async def conn_lost(request):
response = text("Ok")
response.output = Mock(side_effect=RuntimeError)
request.transport.close()
return response
with pytest.raises(ValueError):
# catch ValueError: Exception during request
app.test_client.get("/conn_lost", debug=debug)
req, res = app.test_client.get("/conn_lost", debug=debug)
assert res is None
log = stream.getvalue()
@@ -126,7 +126,8 @@ def test_log_connection_lost(app, debug, monkeypatch):
assert "Connection lost before response written @" not in log
def test_logger(caplog):
@pytest.mark.asyncio
async def test_logger(caplog):
rand_string = str(uuid.uuid4())
app = Sanic(name=__name__)
@@ -137,32 +138,16 @@ def test_logger(caplog):
return text("hello")
with caplog.at_level(logging.INFO):
request, response = app.test_client.get("/")
_ = await app.asgi_client.get("/")
port = request.server_port
# Note: testing with random port doesn't show the banner because it doesn't
# define host and port. This test supports both modes.
if caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ http://127.0.0.1:{port}",
):
caplog.record_tuples.pop(0)
assert caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"http://127.0.0.1:{port}/",
)
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == (
"sanic.root",
logging.INFO,
"Server Stopped",
)
record = ("sanic.root", logging.INFO, rand_string)
assert record in caplog.record_tuples
@pytest.mark.skipif(
OS_IS_WINDOWS and sys.version_info >= (3, 8),
reason="Not testable with current client",
)
def test_logger_static_and_secure(caplog):
# Same as test_logger, except for more coverage:
# - test_client initialised separately for static port

View File

@@ -1,6 +1,7 @@
import logging
from asyncio import CancelledError
from itertools import count
from sanic.exceptions import NotFound, SanicException
from sanic.request import Request
@@ -54,7 +55,7 @@ def test_middleware_response(app):
def test_middleware_response_exception(app):
result = {"status_code": None}
result = {"status_code": "middleware not run"}
@app.middleware("response")
async def process_response(request, response):
@@ -184,3 +185,23 @@ def test_middleware_order(app):
assert response.status == 200
assert order == [1, 2, 3, 4, 5, 6]
def test_request_middleware_executes_once(app):
i = count()
@app.middleware("request")
async def inc(request):
nonlocal i
next(i)
@app.route("/")
async def handler(request):
await request.app._run_request_middleware(request)
return text("OK")
request, response = app.test_client.get("/")
assert next(i) == 1
request, response = app.test_client.get("/")
assert next(i) == 3

View File

@@ -85,7 +85,6 @@ def test_pickle_app_with_bp(app, protocol):
up_p_app = pickle.loads(p_app)
assert up_p_app
request, response = up_p_app.test_client.get("/")
assert up_p_app.is_request_stream is False
assert response.text == "Hello"

View File

@@ -107,10 +107,8 @@ def test_shorthand_named_routes_post(app):
def test_shorthand_named_routes_put(app):
@app.put("/put", name="route_put")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/put"].name == "route_put"
assert app.url_for("route_put") == "/put"
with pytest.raises(URLBuildError):
@@ -120,10 +118,8 @@ def test_shorthand_named_routes_put(app):
def test_shorthand_named_routes_delete(app):
@app.delete("/delete", name="route_delete")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/delete"].name == "route_delete"
assert app.url_for("route_delete") == "/delete"
with pytest.raises(URLBuildError):
@@ -133,10 +129,8 @@ def test_shorthand_named_routes_delete(app):
def test_shorthand_named_routes_patch(app):
@app.patch("/patch", name="route_patch")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/patch"].name == "route_patch"
assert app.url_for("route_patch") == "/patch"
with pytest.raises(URLBuildError):
@@ -146,10 +140,8 @@ def test_shorthand_named_routes_patch(app):
def test_shorthand_named_routes_head(app):
@app.head("/head", name="route_head")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/head"].name == "route_head"
assert app.url_for("route_head") == "/head"
with pytest.raises(URLBuildError):
@@ -159,10 +151,8 @@ def test_shorthand_named_routes_head(app):
def test_shorthand_named_routes_options(app):
@app.options("/options", name="route_options")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
assert app.router.routes_all["/options"].name == "route_options"
assert app.url_for("route_options") == "/options"
with pytest.raises(URLBuildError):

View File

@@ -27,7 +27,7 @@ def test_payload_too_large_at_data_received_default(app):
response = app.test_client.get("/1", gather_request=False)
assert response.status == 413
assert "Payload Too Large" in response.text
assert "Request header" in response.text
def test_payload_too_large_at_on_header_default(app):
@@ -40,4 +40,4 @@ def test_payload_too_large_at_on_header_default(app):
data = "a" * 1000
response = app.test_client.post("/1", gather_request=False, data=data)
assert response.status == 413
assert "Payload Too Large" in response.text
assert "Request body" in response.text

View File

@@ -1,37 +0,0 @@
import io
from sanic.response import text
data = "abc" * 10_000_000
def test_request_buffer_queue_size(app):
default_buf_qsz = app.config.get("REQUEST_BUFFER_QUEUE_SIZE")
qsz = 1
while qsz == default_buf_qsz:
qsz += 1
app.config.REQUEST_BUFFER_QUEUE_SIZE = qsz
@app.post("/post", stream=True)
async def post(request):
assert request.stream.buffer_size == qsz
print("request.stream.buffer_size =", request.stream.buffer_size)
bio = io.BytesIO()
while True:
bdata = await request.stream.read()
if not bdata:
break
bio.write(bdata)
head = bdata[:3].decode("utf-8")
tail = bdata[3:][-3:].decode("utf-8")
print(head, "...", tail)
bio.seek(0)
return text(bio.read().decode("utf-8"))
request, response = app.test_client.post("/post", data=data)
assert response.status == 200
assert response.text == data

View File

@@ -1,11 +1,13 @@
import asyncio
from contextlib import closing
from socket import socket
import pytest
from sanic import Sanic
from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed
from sanic.request import StreamBuffer
from sanic.response import json, stream, text
from sanic.response import json, text
from sanic.server import HttpProtocol
from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator
@@ -15,28 +17,22 @@ data = "abc" * 1_000_000
def test_request_stream_method_view(app):
"""for self.is_request_stream = True"""
class SimpleView(HTTPMethodView):
def get(self, request):
assert request.stream is None
return text("OK")
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = ""
result = b""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
result += body
return text(result.decode())
app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
request, response = app.test_client.get("/method_view")
assert response.status == 200
assert response.text == "OK"
@@ -50,14 +46,16 @@ def test_request_stream_method_view(app):
"headers, expect_raise_exception",
[
({"EXPECT": "100-continue"}, False),
({"EXPECT": "100-continue-extra"}, True),
# The below test SHOULD work, and it does produce a 417
# However, httpx now intercepts this and raises an exception,
# so we will need a new method for testing this
# ({"EXPECT": "100-continue-extra"}, True),
],
)
def test_request_stream_100_continue(app, headers, expect_raise_exception):
class SimpleView(HTTPMethodView):
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -68,55 +66,42 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception):
app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
if not expect_raise_exception:
request, response = app.test_client.post(
"/method_view", data=data, headers={"EXPECT": "100-continue"}
"/method_view", data=data, headers=headers
)
assert response.status == 200
assert response.text == data
else:
with pytest.raises(ValueError) as e:
app.test_client.post(
"/method_view",
data=data,
headers={"EXPECT": "100-continue-extra"},
)
assert "Unknown Expect: 100-continue-extra" in str(e)
request, response = app.test_client.post(
"/method_view", data=data, headers=headers
)
assert response.status == 417
def test_request_stream_app(app):
"""for self.is_request_stream = True and decorators"""
@app.get("/get")
async def get(request):
assert request.stream is None
return text("GET")
@app.head("/head")
async def head(request):
assert request.stream is None
return text("HEAD")
@app.delete("/delete")
async def delete(request):
assert request.stream is None
return text("DELETE")
@app.options("/options")
async def options(request):
assert request.stream is None
return text("OPTIONS")
@app.post("/_post/<id>")
async def _post(request, id):
assert request.stream is None
return text("_POST")
@app.post("/post/<id>", stream=True)
async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -127,12 +112,10 @@ def test_request_stream_app(app):
@app.put("/_put")
async def _put(request):
assert request.stream is None
return text("_PUT")
@app.put("/put", stream=True)
async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -143,12 +126,10 @@ def test_request_stream_app(app):
@app.patch("/_patch")
async def _patch(request):
assert request.stream is None
return text("_PATCH")
@app.patch("/patch", stream=True)
async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -157,8 +138,6 @@ def test_request_stream_app(app):
result += body.decode("utf-8")
return text(result)
assert app.is_request_stream is True
request, response = app.test_client.get("/get")
assert response.status == 200
assert response.text == "GET"
@@ -202,36 +181,28 @@ def test_request_stream_app(app):
@pytest.mark.asyncio
async def test_request_stream_app_asgi(app):
"""for self.is_request_stream = True and decorators"""
@app.get("/get")
async def get(request):
assert request.stream is None
return text("GET")
@app.head("/head")
async def head(request):
assert request.stream is None
return text("HEAD")
@app.delete("/delete")
async def delete(request):
assert request.stream is None
return text("DELETE")
@app.options("/options")
async def options(request):
assert request.stream is None
return text("OPTIONS")
@app.post("/_post/<id>")
async def _post(request, id):
assert request.stream is None
return text("_POST")
@app.post("/post/<id>", stream=True)
async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -242,12 +213,10 @@ async def test_request_stream_app_asgi(app):
@app.put("/_put")
async def _put(request):
assert request.stream is None
return text("_PUT")
@app.put("/put", stream=True)
async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -258,12 +227,10 @@ async def test_request_stream_app_asgi(app):
@app.patch("/_patch")
async def _patch(request):
assert request.stream is None
return text("_PATCH")
@app.patch("/patch", stream=True)
async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -272,8 +239,6 @@ async def test_request_stream_app_asgi(app):
result += body.decode("utf-8")
return text(result)
assert app.is_request_stream is True
request, response = await app.asgi_client.get("/get")
assert response.status == 200
assert response.text == "GET"
@@ -320,14 +285,13 @@ def test_request_stream_handle_exception(app):
@app.post("/post/<id>", stream=True)
async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = ""
result = b""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
result += body
return text(result.decode())
# 404
request, response = app.test_client.post("/in_valid_post", data=data)
@@ -340,54 +304,31 @@ def test_request_stream_handle_exception(app):
assert "Method GET not allowed for URL /post/random_id" in response.text
@pytest.mark.asyncio
async def test_request_stream_unread(app):
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
err = None
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
try:
protocol.request = None
protocol._body_chunks.append("this is a test")
await protocol.stream_append()
except AttributeError as e:
err = e
assert err is None and not protocol._body_chunks
def test_request_stream_blueprint(app):
"""for self.is_request_stream = True"""
bp = Blueprint("test_blueprint_request_stream_blueprint")
@app.get("/get")
async def get(request):
assert request.stream is None
return text("GET")
@bp.head("/head")
async def head(request):
assert request.stream is None
return text("HEAD")
@bp.delete("/delete")
async def delete(request):
assert request.stream is None
return text("DELETE")
@bp.options("/options")
async def options(request):
assert request.stream is None
return text("OPTIONS")
@bp.post("/_post/<id>")
async def _post(request, id):
assert request.stream is None
return text("_POST")
@bp.post("/post/<id>", stream=True)
async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -398,12 +339,10 @@ def test_request_stream_blueprint(app):
@bp.put("/_put")
async def _put(request):
assert request.stream is None
return text("_PUT")
@bp.put("/put", stream=True)
async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -414,12 +353,10 @@ def test_request_stream_blueprint(app):
@bp.patch("/_patch")
async def _patch(request):
assert request.stream is None
return text("_PATCH")
@bp.patch("/patch", stream=True)
async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -429,7 +366,6 @@ def test_request_stream_blueprint(app):
return text(result)
async def post_add_route(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -443,8 +379,6 @@ def test_request_stream_blueprint(app):
)
app.blueprint(bp)
assert app.is_request_stream is True
request, response = app.test_client.get("/get")
assert response.status == 200
assert response.text == "GET"
@@ -491,14 +425,10 @@ def test_request_stream_blueprint(app):
def test_request_stream_composition_view(app):
"""for self.is_request_stream = True"""
def get_handler(request):
assert request.stream is None
return text("OK")
async def post_handler(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -512,8 +442,6 @@ def test_request_stream_composition_view(app):
view.add(["POST"], post_handler, stream=True)
app.add_route(view, "/composition_view")
assert app.is_request_stream is True
request, response = app.test_client.get("/composition_view")
assert response.status == 200
assert response.text == "OK"
@@ -529,12 +457,10 @@ def test_request_stream(app):
class SimpleView(HTTPMethodView):
def get(self, request):
assert request.stream is None
return text("OK")
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -545,7 +471,6 @@ def test_request_stream(app):
@app.post("/stream", stream=True)
async def handler(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -556,12 +481,10 @@ def test_request_stream(app):
@app.get("/get")
async def get(request):
assert request.stream is None
return text("OK")
@bp.post("/bp_stream", stream=True)
async def bp_stream(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -572,15 +495,12 @@ def test_request_stream(app):
@bp.get("/bp_get")
async def bp_get(request):
assert request.stream is None
return text("OK")
def get_handler(request):
assert request.stream is None
return text("OK")
async def post_handler(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
@@ -599,8 +519,6 @@ def test_request_stream(app):
app.add_route(view, "/composition_view")
assert app.is_request_stream is True
request, response = app.test_client.get("/method_view")
assert response.status == 200
assert response.text == "OK"
@@ -636,14 +554,14 @@ def test_request_stream(app):
def test_streaming_new_api(app):
@app.post("/non-stream")
async def handler(request):
async def handler1(request):
assert request.body == b"x"
await request.receive_body() # This should do nothing
assert request.body == b"x"
return text("OK")
@app.post("/1", stream=True)
async def handler(request):
async def handler2(request):
assert request.stream
assert not request.body
await request.receive_body()
@@ -671,5 +589,85 @@ def test_streaming_new_api(app):
assert response.status == 200
res = response.json
assert isinstance(res, list)
assert len(res) > 1
assert "".join(res) == data
def test_streaming_echo():
"""2-way streaming chat between server and client."""
app = Sanic(name=__name__)
@app.post("/echo", stream=True)
async def handler(request):
res = await request.respond(content_type="text/plain; charset=utf-8")
# Send headers
await res.send(end_stream=False)
# Echo back data (case swapped)
async for data in request.stream:
await res.send(data.swapcase())
# Add EOF marker after successful operation
await res.send(b"-", end_stream=True)
@app.listener("after_server_start")
async def client_task(app, loop):
try:
reader, writer = await asyncio.open_connection(*addr)
await client(app, reader, writer)
finally:
writer.close()
app.stop()
async def client(app, reader, writer):
# Unfortunately httpx does not support 2-way streaming, so do it by hand.
host = f"host: {addr[0]}:{addr[1]}\r\n".encode()
writer.write(
b"POST /echo HTTP/1.1\r\n" + host + b"content-length: 2\r\n"
b"content-type: text/plain; charset=utf-8\r\n"
b"\r\n"
)
# Read response
res = b""
while not b"\r\n\r\n" in res:
res += await reader.read(4096)
assert res.startswith(b"HTTP/1.1 200 OK\r\n")
assert res.endswith(b"\r\n\r\n")
buffer = b""
async def read_chunk():
nonlocal buffer
while not b"\r\n" in buffer:
data = await reader.read(4096)
assert data
buffer += data
size, buffer = buffer.split(b"\r\n", 1)
size = int(size, 16)
if size == 0:
return None
while len(buffer) < size + 2:
data = await reader.read(4096)
assert data
buffer += data
print(res)
assert buffer[size : size + 2] == b"\r\n"
ret, buffer = buffer[:size], buffer[size + 2 :]
return ret
# Chat with server
writer.write(b"a")
res = await read_chunk()
assert res == b"A"
writer.write(b"b")
res = await read_chunk()
assert res == b"B"
res = await read_chunk()
assert res == b"-"
res = await read_chunk()
assert res == None
# Use random port for tests
with closing(socket()) as sock:
sock.bind(("127.0.0.1", 0))
addr = sock.getsockname()
app.run(sock=sock, access_log=False)

View File

@@ -59,7 +59,7 @@ def test_ip(app):
@pytest.mark.asyncio
async def test_ip_asgi(app):
async def test_url_asgi(app):
@app.route("/")
def handler(request):
return text(f"{request.url}")
@@ -2119,3 +2119,37 @@ def test_url_for_without_server_name(app):
response.json["url"]
== f"http://127.0.0.1:{request.server_port}/url-for"
)
def test_safe_method_with_body_ignored(app):
@app.get("/")
async def handler(request):
return text("OK")
payload = {"test": "OK"}
headers = {"content-type": "application/json"}
request, response = app.test_client.request(
"/", http_method="get", data=json_dumps(payload), headers=headers
)
assert request.body == b""
assert request.json == None
assert response.text == "OK"
def test_safe_method_with_body(app):
@app.get("/", ignore_body=False)
async def handler(request):
return text("OK")
payload = {"test": "OK"}
headers = {"content-type": "application/json"}
data = json_dumps(payload)
request, response = app.test_client.request(
"/", http_method="get", data=data, headers=headers
)
assert request.body == data.encode("utf-8")
assert request.json.get("test") == "OK"
assert response.text == "OK"

View File

@@ -85,7 +85,6 @@ def test_response_header(app):
request, response = app.test_client.get("/")
assert dict(response.headers) == {
"connection": "keep-alive",
"keep-alive": str(app.config.KEEP_ALIVE_TIMEOUT),
"content-length": "11",
"content-type": "application/json",
}
@@ -201,7 +200,6 @@ def streaming_app(app):
async def test(request):
return stream(
sample_streaming_fn,
headers={"Content-Length": "7"},
content_type="text/csv",
)
@@ -243,7 +241,12 @@ async def test_chunked_streaming_returns_correct_content_asgi(streaming_app):
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
request, response = non_chunked_streaming_app.test_client.get("/")
with pytest.warns(UserWarning) as record:
request, response = non_chunked_streaming_app.test_client.get("/")
assert len(record) == 1
assert "removed in v21.6" in record[0].message.args[0]
assert "Transfer-Encoding" not in response.headers
assert response.headers["Content-Type"] == "text/csv"
assert response.headers["Content-Length"] == "7"
@@ -266,115 +269,6 @@ def test_non_chunked_streaming_returns_correct_content(
assert response.text == "foo,bar"
@pytest.mark.parametrize("status", [200, 201, 400, 401])
def test_stream_response_status_returns_correct_headers(status):
response = StreamingHTTPResponse(sample_streaming_fn, status=status)
headers = response.get_headers()
assert b"HTTP/1.1 %s" % str(status).encode() in headers
@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
def test_stream_response_keep_alive_returns_correct_headers(
keep_alive_timeout,
):
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(
keep_alive=True, keep_alive_timeout=keep_alive_timeout
)
assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers
def test_stream_response_includes_chunked_header_http11():
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(version="1.1")
assert b"Transfer-Encoding: chunked\r\n" in headers
def test_stream_response_does_not_include_chunked_header_http10():
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(version="1.0")
assert b"Transfer-Encoding: chunked\r\n" not in headers
def test_stream_response_does_not_include_chunked_header_if_disabled():
response = StreamingHTTPResponse(sample_streaming_fn, chunked=False)
headers = response.get_headers(version="1.1")
assert b"Transfer-Encoding: chunked\r\n" not in headers
def test_stream_response_writes_correct_content_to_transport_when_chunked(
streaming_app,
):
response = StreamingHTTPResponse(sample_streaming_fn)
response.protocol = MagicMock(HttpProtocol)
response.protocol.transport = MagicMock(asyncio.Transport)
async def mock_drain():
pass
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data
response.protocol.drain = mock_drain
@streaming_app.listener("after_server_start")
async def run_stream(app, loop):
await response.stream()
assert response.protocol.transport.write.call_args_list[1][0][0] == (
b"4\r\nfoo,\r\n"
)
assert response.protocol.transport.write.call_args_list[2][0][0] == (
b"3\r\nbar\r\n"
)
assert response.protocol.transport.write.call_args_list[3][0][0] == (
b"0\r\n\r\n"
)
assert len(response.protocol.transport.write.call_args_list) == 4
app.stop()
streaming_app.run(host=HOST, port=PORT)
def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
streaming_app,
):
response = StreamingHTTPResponse(sample_streaming_fn)
response.protocol = MagicMock(HttpProtocol)
response.protocol.transport = MagicMock(asyncio.Transport)
async def mock_drain():
pass
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data
response.protocol.drain = mock_drain
@streaming_app.listener("after_server_start")
async def run_stream(app, loop):
await response.stream(version="1.0")
assert response.protocol.transport.write.call_args_list[1][0][0] == (
b"foo,"
)
assert response.protocol.transport.write.call_args_list[2][0][0] == (
b"bar"
)
assert len(response.protocol.transport.write.call_args_list) == 3
app.stop()
streaming_app.run(host=HOST, port=PORT)
def test_stream_response_with_cookies(app):
@app.route("/")
async def test(request):

View File

@@ -67,16 +67,12 @@ def test_shorthand_routes_multiple(app):
def test_route_strict_slash(app):
@app.get("/get", strict_slashes=True)
def handler1(request):
assert request.stream is None
return text("OK")
@app.post("/post/", strict_slashes=True)
def handler2(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.get("/get")
assert response.text == "OK"
@@ -216,11 +212,8 @@ def test_shorthand_routes_post(app):
def test_shorthand_routes_put(app):
@app.put("/put")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.put("/put")
assert response.text == "OK"
@@ -231,11 +224,8 @@ def test_shorthand_routes_put(app):
def test_shorthand_routes_delete(app):
@app.delete("/delete")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.delete("/delete")
assert response.text == "OK"
@@ -246,11 +236,8 @@ def test_shorthand_routes_delete(app):
def test_shorthand_routes_patch(app):
@app.patch("/patch")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.patch("/patch")
assert response.text == "OK"
@@ -261,11 +248,8 @@ def test_shorthand_routes_patch(app):
def test_shorthand_routes_head(app):
@app.head("/head")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.head("/head")
assert response.status == 200
@@ -276,11 +260,8 @@ def test_shorthand_routes_head(app):
def test_shorthand_routes_options(app):
@app.options("/options")
def handler(request):
assert request.stream is None
return text("OK")
assert app.is_request_stream is False
request, response = app.test_client.options("/options")
assert response.status == 200

View File

@@ -0,0 +1,74 @@
import asyncio
from time import monotonic as current_time
from unittest.mock import Mock
import pytest
from sanic import Sanic
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Stage
from sanic.server import HttpProtocol
@pytest.fixture
def app():
return Sanic("test")
@pytest.fixture
def mock_transport():
return Mock()
@pytest.fixture
def protocol(app, mock_transport):
loop = asyncio.new_event_loop()
protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport)
protocol._setup_connection()
protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock()
return protocol
def test_setup(protocol: HttpProtocol):
assert protocol._task is not None
assert protocol._http is not None
assert protocol._time is not None
def test_check_timeouts_no_timeout(protocol: HttpProtocol):
protocol.keep_alive_timeout = 1
protocol.loop.call_later = Mock()
protocol.check_timeouts()
protocol._task.cancel.assert_not_called()
assert protocol._http.stage is Stage.IDLE
assert protocol._http.exception is None
protocol.loop.call_later.assert_called_with(
protocol.keep_alive_timeout / 2, protocol.check_timeouts
)
def test_check_timeouts_keep_alive_timeout(protocol: HttpProtocol):
protocol._http.stage = Stage.IDLE
protocol._time = 0
protocol.check_timeouts()
protocol._task.cancel.assert_called_once()
assert protocol._http.exception is None
def test_check_timeouts_request_timeout(protocol: HttpProtocol):
protocol._http.stage = Stage.REQUEST
protocol._time = 0
protocol.check_timeouts()
protocol._task.cancel.assert_called_once()
assert isinstance(protocol._http.exception, RequestTimeout)
def test_check_timeouts_response_timeout(protocol: HttpProtocol):
protocol._http.stage = Stage.RESPONSE
protocol._time = 0
protocol.check_timeouts()
protocol._task.cancel.assert_called_once()
assert isinstance(protocol._http.exception, ServiceUnavailable)

View File

@@ -12,7 +12,6 @@ from sanic.views import CompositionView, HTTPMethodView
def test_methods(app, method):
class DummyView(HTTPMethodView):
async def get(self, request):
assert request.stream is None
return text("", headers={"method": "GET"})
def post(self, request):
@@ -34,7 +33,6 @@ def test_methods(app, method):
return text("", headers={"method": "DELETE"})
app.add_route(DummyView.as_view(), "/")
assert app.is_request_stream is False
request, response = getattr(app.test_client, method.lower())("/")
assert response.headers["method"] == method
@@ -69,7 +67,6 @@ def test_with_bp(app):
class DummyView(HTTPMethodView):
def get(self, request):
assert request.stream is None
return text("I am get method")
bp.add_route(DummyView.as_view(), "/")
@@ -77,7 +74,6 @@ def test_with_bp(app):
app.blueprint(bp)
request, response = app.test_client.get("/")
assert app.is_request_stream is False
assert response.text == "I am get method"
@@ -210,14 +206,12 @@ def test_composition_view_runs_methods_as_expected(app, method):
view = CompositionView()
def first(request):
assert request.stream is None
return text("first method")
view.add(["GET", "POST", "PUT"], first)
view.add(["DELETE", "PATCH"], lambda x: text("second method"))
app.add_route(view, "/")
assert app.is_request_stream is False
if method in ["GET", "POST", "PUT"]:
request, response = getattr(app.test_client, method.lower())("/")