From fe64a2764d6429511a41fd396c9e1df4a8d3ba7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Fri, 21 Feb 2020 13:28:50 +0200 Subject: [PATCH] Make all requests streaming and preload body for non-streaming handlers. --- sanic/app.py | 10 +++--- sanic/asgi.py | 21 +++--------- sanic/request.py | 12 +------ sanic/server.py | 62 ++++++++++++++--------------------- tests/test_blueprints.py | 11 ------- tests/test_custom_request.py | 4 --- tests/test_multiprocessing.py | 1 - tests/test_named_routes.py | 10 ------ tests/test_request_stream.py | 62 ++++------------------------------- tests/test_routes.py | 19 ----------- tests/test_views.py | 6 ---- 11 files changed, 41 insertions(+), 177 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index abdd36fb..5da7c925 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -82,7 +82,6 @@ class Sanic: self.strict_slashes = strict_slashes self.listeners = defaultdict(list) self.is_running = False - self.is_request_stream = False self.websocket_enabled = False self.websocket_tasks = set() self.named_request_middleware = {} @@ -187,9 +186,6 @@ class Sanic: if not uri.startswith("/"): uri = "/" + uri - if stream: - self.is_request_stream = True - if strict_slashes is None: strict_slashes = self.strict_slashes @@ -956,6 +952,10 @@ class Sanic: # Fetch handler from router handler, args, kwargs, uri, name = self.router.get(request) + # Non-streaming handlers have their body preloaded + if not self.router.is_stream_handler(request): + await request.receive_body() + # -------------------------------------------- # # Request Middleware # -------------------------------------------- # @@ -1381,7 +1381,7 @@ class Sanic: server_settings = { "protocol": protocol, "request_class": self.request_class, - "is_request_stream": self.is_request_stream, + "is_request_stream": True, "router": self.router, "host": host, "port": port, diff --git a/sanic/asgi.py b/sanic/asgi.py index f08cc454..b1ebc9c3 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -190,7 +190,6 @@ class ASGIApp: sanic_app: "sanic.app.Sanic" request: Request transport: MockTransport - do_stream: bool lifespan: Lifespan ws: Optional[WebSocketConnection] @@ -213,9 +212,6 @@ class ASGIApp: for key, value in scope.get("headers", []) ] ) - instance.do_stream = ( - True if headers.get("expect") == "100-continue" else False - ) instance.lifespan = Lifespan(instance) if scope["type"] == "lifespan": @@ -256,15 +252,9 @@ class ASGIApp: sanic_app, ) - if sanic_app.is_request_stream: - is_stream_handler = sanic_app.router.is_stream_handler( - instance.request - ) - if is_stream_handler: - instance.request.stream = StreamBuffer( - sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE - ) - instance.do_stream = True + instance.request.stream = StreamBuffer( + sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE + ) return instance @@ -300,10 +290,7 @@ class ASGIApp: """ Handle the incoming request. """ - if not self.do_stream: - self.request.body = await self.read_body() - else: - self.sanic_app.loop.create_task(self.stream_body()) + self.sanic_app.loop.create_task(self.stream_body()) handler = self.sanic_app.handle_request callback = None if self.ws else self.stream_callback diff --git a/sanic/request.py b/sanic/request.py index 4087dc33..6cb1b73f 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -116,7 +116,7 @@ class Request: self.transport = transport # Init but do not inhale - self.body_init() + self.body = None self.ctx = SimpleNamespace() self.parsed_forwarded = None self.parsed_json = None @@ -159,17 +159,7 @@ class Request: Custom context is now stored in `request.custom_context.yourkey`""" setattr(self.ctx, key, value) - def body_init(self): - self.body = [] - - def body_push(self, data): - self.body.append(data) - - def body_finish(self): - self.body = b"".join(self.body) - async def receive_body(self): - assert self.body == [] self.body = b"".join([data async for data in self.stream]) @property diff --git a/sanic/server.py b/sanic/server.py index b9e72197..97242555 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -328,14 +328,8 @@ class HttpProtocol(asyncio.Protocol): self.expect_handler() if self.is_request_stream: - self._is_stream_handler = self.router.is_stream_handler( - self.request - ) - if self._is_stream_handler: - self.request.stream = StreamBuffer( - self.request_buffer_queue_size - ) - self.execute_request_handler() + self.request.stream = StreamBuffer(self.request_buffer_queue_size) + self.execute_request_handler() def expect_handler(self): """ @@ -353,21 +347,18 @@ class HttpProtocol(asyncio.Protocol): ) def on_body(self, body): - if self.is_request_stream and self._is_stream_handler: - # body chunks can be put into asyncio.Queue out of order if - # multiple tasks put concurrently and the queue is full in python - # 3.7. so we should not create more than one task putting into the - # queue simultaneously. - self._body_chunks.append(body) - if ( - not self._request_stream_task - or self._request_stream_task.done() - ): - self._request_stream_task = self.loop.create_task( - self.stream_append() - ) - else: - self.request.body_push(body) + # body chunks can be put into asyncio.Queue out of order if + # multiple tasks put concurrently and the queue is full in python + # 3.7. so we should not create more than one task putting into the + # queue simultaneously. + self._body_chunks.append(body) + if ( + not self._request_stream_task + or self._request_stream_task.done() + ): + self._request_stream_task = self.loop.create_task( + self.stream_append() + ) async def body_append(self, body): if ( @@ -385,7 +376,7 @@ class HttpProtocol(asyncio.Protocol): await self.request.stream.put(body) async def stream_append(self): - while self._body_chunks: + while self._body_chunks and self.request: body = self._body_chunks.popleft() if self.request.stream.is_full(): self.transport.pause_reading() @@ -393,6 +384,7 @@ class HttpProtocol(asyncio.Protocol): self.transport.resume_reading() else: await self.request.stream.put(body) + self._body_chunks.clear() def on_message_complete(self): # Entire request (headers and whole body) is received. @@ -400,18 +392,15 @@ class HttpProtocol(asyncio.Protocol): if self._request_timeout_handler: self._request_timeout_handler.cancel() self._request_timeout_handler = None - if self.is_request_stream and self._is_stream_handler: - self._body_chunks.append(None) - if ( - not self._request_stream_task - or self._request_stream_task.done() - ): - self._request_stream_task = self.loop.create_task( - self.stream_append() - ) - return - self.request.body_finish() - self.execute_request_handler() + + self._body_chunks.append(None) + if ( + not self._request_stream_task + or self._request_stream_task.done() + ): + self._request_stream_task = self.loop.create_task( + self.stream_append() + ) def execute_request_handler(self): """ @@ -639,7 +628,6 @@ class HttpProtocol(asyncio.Protocol): self._request_handler_task = None self._request_stream_task = None self._total_request_size = 0 - self._is_stream_handler = False def close_if_idle(self): """Close the connection if a request is not being sent or received diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index c21f3807..55353efa 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -71,7 +71,6 @@ def test_bp(app): app.blueprint(bp) request, response = app.test_client.get("/") - assert app.is_request_stream is False assert response.text == "Hello" @@ -383,48 +382,38 @@ def test_bp_shorthand(app): @blueprint.get("/get") def handler(request): - assert request.stream is None return text("OK") @blueprint.put("/put") def put_handler(request): - assert request.stream is None return text("OK") @blueprint.post("/post") def post_handler(request): - assert request.stream is None return text("OK") @blueprint.head("/head") def head_handler(request): - assert request.stream is None return text("OK") @blueprint.options("/options") def options_handler(request): - assert request.stream is None return text("OK") @blueprint.patch("/patch") def patch_handler(request): - assert request.stream is None return text("OK") @blueprint.delete("/delete") def delete_handler(request): - assert request.stream is None return text("OK") @blueprint.websocket("/ws/", strict_slashes=True) async def websocket_handler(request, ws): - assert request.stream is None ev.set() app.blueprint(blueprint) - assert app.is_request_stream is False - request, response = app.test_client.get("/get") assert response.text == "OK" diff --git a/tests/test_custom_request.py b/tests/test_custom_request.py index d0ae48e7..81f20c8e 100644 --- a/tests/test_custom_request.py +++ b/tests/test_custom_request.py @@ -37,8 +37,6 @@ def test_custom_request(): "/post", data=json_dumps(payload), headers=headers ) - assert isinstance(request.body_buffer, BytesIO) - assert request.body_buffer.closed assert request.body == b'{"test":"OK"}' assert request.json.get("test") == "OK" assert response.text == "OK" @@ -46,8 +44,6 @@ def test_custom_request(): request, response = app.test_client.get("/get") - assert isinstance(request.body_buffer, BytesIO) - assert request.body_buffer.closed assert request.body == b"" assert response.text == "OK" assert response.status == 200 diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 3cd60e56..f1e53933 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -85,5 +85,4 @@ def test_pickle_app_with_bp(app, protocol): up_p_app = pickle.loads(p_app) assert up_p_app request, response = up_p_app.test_client.get("/") - assert up_p_app.is_request_stream is False assert response.text == "Hello" diff --git a/tests/test_named_routes.py b/tests/test_named_routes.py index 7783e454..29b8e40d 100644 --- a/tests/test_named_routes.py +++ b/tests/test_named_routes.py @@ -107,10 +107,8 @@ def test_shorthand_named_routes_post(app): def test_shorthand_named_routes_put(app): @app.put("/put", name="route_put") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False assert app.router.routes_all["/put"].name == "route_put" assert app.url_for("route_put") == "/put" with pytest.raises(URLBuildError): @@ -120,10 +118,8 @@ def test_shorthand_named_routes_put(app): def test_shorthand_named_routes_delete(app): @app.delete("/delete", name="route_delete") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False assert app.router.routes_all["/delete"].name == "route_delete" assert app.url_for("route_delete") == "/delete" with pytest.raises(URLBuildError): @@ -133,10 +129,8 @@ def test_shorthand_named_routes_delete(app): def test_shorthand_named_routes_patch(app): @app.patch("/patch", name="route_patch") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False assert app.router.routes_all["/patch"].name == "route_patch" assert app.url_for("route_patch") == "/patch" with pytest.raises(URLBuildError): @@ -146,10 +140,8 @@ def test_shorthand_named_routes_patch(app): def test_shorthand_named_routes_head(app): @app.head("/head", name="route_head") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False assert app.router.routes_all["/head"].name == "route_head" assert app.url_for("route_head") == "/head" with pytest.raises(URLBuildError): @@ -159,10 +151,8 @@ def test_shorthand_named_routes_head(app): def test_shorthand_named_routes_options(app): @app.options("/options", name="route_options") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False assert app.router.routes_all["/options"].name == "route_options" assert app.url_for("route_options") == "/options" with pytest.raises(URLBuildError): diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index b76e7248..756956c9 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -12,28 +12,23 @@ data = "abc" * 10000000 def test_request_stream_method_view(app): - """for self.is_request_stream = True""" - class SimpleView(HTTPMethodView): def get(self, request): - assert request.stream is None return text("OK") @stream_decorator async def post(self, request): assert isinstance(request.stream, StreamBuffer) - result = "" + result = b"" while True: body = await request.stream.read() if body is None: break - result += body.decode("utf-8") - return text(result) + result += body + return text(result.decode()) app.add_route(SimpleView.as_view(), "/method_view") - assert app.is_request_stream is True - request, response = app.test_client.get("/method_view") assert response.status == 200 assert response.text == "OK" @@ -65,8 +60,6 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception): app.add_route(SimpleView.as_view(), "/method_view") - assert app.is_request_stream is True - if not expect_raise_exception: request, response = app.test_client.post( "/method_view", data=data, headers={"EXPECT": "100-continue"} @@ -84,31 +77,24 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception): def test_request_stream_app(app): - """for self.is_request_stream = True and decorators""" - @app.get("/get") async def get(request): - assert request.stream is None return text("GET") @app.head("/head") async def head(request): - assert request.stream is None return text("HEAD") @app.delete("/delete") async def delete(request): - assert request.stream is None return text("DELETE") @app.options("/options") async def options(request): - assert request.stream is None return text("OPTIONS") @app.post("/_post/") async def _post(request, id): - assert request.stream is None return text("_POST") @app.post("/post/", stream=True) @@ -124,7 +110,6 @@ def test_request_stream_app(app): @app.put("/_put") async def _put(request): - assert request.stream is None return text("_PUT") @app.put("/put", stream=True) @@ -140,7 +125,6 @@ def test_request_stream_app(app): @app.patch("/_patch") async def _patch(request): - assert request.stream is None return text("_PATCH") @app.patch("/patch", stream=True) @@ -154,8 +138,6 @@ def test_request_stream_app(app): result += body.decode("utf-8") return text(result) - assert app.is_request_stream is True - request, response = app.test_client.get("/get") assert response.status == 200 assert response.text == "GET" @@ -199,31 +181,24 @@ def test_request_stream_app(app): @pytest.mark.asyncio async def test_request_stream_app_asgi(app): - """for self.is_request_stream = True and decorators""" - @app.get("/get") async def get(request): - assert request.stream is None return text("GET") @app.head("/head") async def head(request): - assert request.stream is None return text("HEAD") @app.delete("/delete") async def delete(request): - assert request.stream is None return text("DELETE") @app.options("/options") async def options(request): - assert request.stream is None return text("OPTIONS") @app.post("/_post/") async def _post(request, id): - assert request.stream is None return text("_POST") @app.post("/post/", stream=True) @@ -239,7 +214,6 @@ async def test_request_stream_app_asgi(app): @app.put("/_put") async def _put(request): - assert request.stream is None return text("_PUT") @app.put("/put", stream=True) @@ -255,7 +229,6 @@ async def test_request_stream_app_asgi(app): @app.patch("/_patch") async def _patch(request): - assert request.stream is None return text("_PATCH") @app.patch("/patch", stream=True) @@ -269,8 +242,6 @@ async def test_request_stream_app_asgi(app): result += body.decode("utf-8") return text(result) - assert app.is_request_stream is True - request, response = await app.asgi_client.get("/get") assert response.status == 200 assert response.text == "GET" @@ -318,13 +289,13 @@ def test_request_stream_handle_exception(app): @app.post("/post/", stream=True) async def post(request, id): assert isinstance(request.stream, StreamBuffer) - result = "" + result = b"" while True: body = await request.stream.read() if body is None: break - result += body.decode("utf-8") - return text(result) + result += body + return text(result.decode()) # 404 request, response = app.test_client.post("/in_valid_post", data=data) @@ -338,32 +309,26 @@ def test_request_stream_handle_exception(app): def test_request_stream_blueprint(app): - """for self.is_request_stream = True""" bp = Blueprint("test_blueprint_request_stream_blueprint") @app.get("/get") async def get(request): - assert request.stream is None return text("GET") @bp.head("/head") async def head(request): - assert request.stream is None return text("HEAD") @bp.delete("/delete") async def delete(request): - assert request.stream is None return text("DELETE") @bp.options("/options") async def options(request): - assert request.stream is None return text("OPTIONS") @bp.post("/_post/") async def _post(request, id): - assert request.stream is None return text("_POST") @bp.post("/post/", stream=True) @@ -379,7 +344,6 @@ def test_request_stream_blueprint(app): @bp.put("/_put") async def _put(request): - assert request.stream is None return text("_PUT") @bp.put("/put", stream=True) @@ -395,7 +359,6 @@ def test_request_stream_blueprint(app): @bp.patch("/_patch") async def _patch(request): - assert request.stream is None return text("_PATCH") @bp.patch("/patch", stream=True) @@ -424,8 +387,6 @@ def test_request_stream_blueprint(app): ) app.blueprint(bp) - assert app.is_request_stream is True - request, response = app.test_client.get("/get") assert response.status == 200 assert response.text == "GET" @@ -472,10 +433,7 @@ def test_request_stream_blueprint(app): def test_request_stream_composition_view(app): - """for self.is_request_stream = True""" - def get_handler(request): - assert request.stream is None return text("OK") async def post_handler(request): @@ -493,8 +451,6 @@ def test_request_stream_composition_view(app): view.add(["POST"], post_handler, stream=True) app.add_route(view, "/composition_view") - assert app.is_request_stream is True - request, response = app.test_client.get("/composition_view") assert response.status == 200 assert response.text == "OK" @@ -510,7 +466,6 @@ def test_request_stream(app): class SimpleView(HTTPMethodView): def get(self, request): - assert request.stream is None return text("OK") @stream_decorator @@ -537,7 +492,6 @@ def test_request_stream(app): @app.get("/get") async def get(request): - assert request.stream is None return text("OK") @bp.post("/bp_stream", stream=True) @@ -553,11 +507,9 @@ def test_request_stream(app): @bp.get("/bp_get") async def bp_get(request): - assert request.stream is None return text("OK") def get_handler(request): - assert request.stream is None return text("OK") async def post_handler(request): @@ -580,8 +532,6 @@ def test_request_stream(app): app.add_route(view, "/composition_view") - assert app.is_request_stream is True - request, response = app.test_client.get("/method_view") assert response.status == 200 assert response.text == "OK" diff --git a/tests/test_routes.py b/tests/test_routes.py index 31fa1a56..9c4d36a6 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -66,16 +66,12 @@ def test_shorthand_routes_multiple(app): def test_route_strict_slash(app): @app.get("/get", strict_slashes=True) def handler1(request): - assert request.stream is None return text("OK") @app.post("/post/", strict_slashes=True) def handler2(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False - request, response = app.test_client.get("/get") assert response.text == "OK" @@ -214,11 +210,8 @@ def test_shorthand_routes_post(app): def test_shorthand_routes_put(app): @app.put("/put") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False - request, response = app.test_client.put("/put") assert response.text == "OK" @@ -229,11 +222,8 @@ def test_shorthand_routes_put(app): def test_shorthand_routes_delete(app): @app.delete("/delete") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False - request, response = app.test_client.delete("/delete") assert response.text == "OK" @@ -244,11 +234,8 @@ def test_shorthand_routes_delete(app): def test_shorthand_routes_patch(app): @app.patch("/patch") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False - request, response = app.test_client.patch("/patch") assert response.text == "OK" @@ -259,11 +246,8 @@ def test_shorthand_routes_patch(app): def test_shorthand_routes_head(app): @app.head("/head") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False - request, response = app.test_client.head("/head") assert response.status == 200 @@ -274,11 +258,8 @@ def test_shorthand_routes_head(app): def test_shorthand_routes_options(app): @app.options("/options") def handler(request): - assert request.stream is None return text("OK") - assert app.is_request_stream is False - request, response = app.test_client.options("/options") assert response.status == 200 diff --git a/tests/test_views.py b/tests/test_views.py index d2844659..c7e3cd5a 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -12,7 +12,6 @@ from sanic.views import CompositionView, HTTPMethodView def test_methods(app, method): class DummyView(HTTPMethodView): async def get(self, request): - assert request.stream is None return text("", headers={"method": "GET"}) def post(self, request): @@ -34,7 +33,6 @@ def test_methods(app, method): return text("", headers={"method": "DELETE"}) app.add_route(DummyView.as_view(), "/") - assert app.is_request_stream is False request, response = getattr(app.test_client, method.lower())("/") assert response.headers["method"] == method @@ -69,7 +67,6 @@ def test_with_bp(app): class DummyView(HTTPMethodView): def get(self, request): - assert request.stream is None return text("I am get method") bp.add_route(DummyView.as_view(), "/") @@ -77,7 +74,6 @@ def test_with_bp(app): app.blueprint(bp) request, response = app.test_client.get("/") - assert app.is_request_stream is False assert response.text == "I am get method" @@ -211,14 +207,12 @@ def test_composition_view_runs_methods_as_expected(app, method): view = CompositionView() def first(request): - assert request.stream is None return text("first method") view.add(["GET", "POST", "PUT"], first) view.add(["DELETE", "PATCH"], lambda x: text("second method")) app.add_route(view, "/") - assert app.is_request_stream is False if method in ["GET", "POST", "PUT"]: request, response = getattr(app.test_client, method.lower())("/")