From f21db608598793cb3cf24bd838ab382a79b6b30a Mon Sep 17 00:00:00 2001 From: Yun Xu Date: Mon, 3 Jun 2019 22:08:24 -0700 Subject: [PATCH 1/4] fix: handle expect header --- sanic/exceptions.py | 5 ++++ sanic/request.py | 2 +- sanic/server.py | 18 ++++++++++++- tests/test_request_stream.py | 49 ++++++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 2 deletions(-) diff --git a/sanic/exceptions.py b/sanic/exceptions.py index b06c76d1..2c4ab2c0 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -218,6 +218,11 @@ class ContentRangeError(SanicException): } +@add_status_code(417) +class HeaderExpectationFailed(SanicException): + pass + + @add_status_code(403) class Forbidden(SanicException): pass diff --git a/sanic/request.py b/sanic/request.py index dfb3d1ff..15c2d5c4 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -29,7 +29,7 @@ except ImportError: DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" - +EXPECT_HEADER = "EXPECT" # HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 # > If the media type remains unknown, the recipient SHOULD treat it diff --git a/sanic/server.py b/sanic/server.py index a2038e3c..b4673cac 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -15,6 +15,7 @@ from httptools.parser.errors import HttpParserError from multidict import CIMultiDict from sanic.exceptions import ( + HeaderExpectationFailed, InvalidUsage, PayloadTooLarge, RequestTimeout, @@ -22,7 +23,7 @@ from sanic.exceptions import ( ServiceUnavailable, ) from sanic.log import access_logger, logger -from sanic.request import Request, StreamBuffer +from sanic.request import Request, StreamBuffer, EXPECT_HEADER from sanic.response import HTTPResponse @@ -314,6 +315,10 @@ class HttpProtocol(asyncio.Protocol): if self._keep_alive_timeout_handler: self._keep_alive_timeout_handler.cancel() self._keep_alive_timeout_handler = None + + if self.request.headers.get(EXPECT_HEADER): + self.expect_handler() + if self.is_request_stream: self._is_stream_handler = self.router.is_stream_handler( self.request @@ -324,6 +329,17 @@ class HttpProtocol(asyncio.Protocol): ) self.execute_request_handler() + def expect_handler(self): + """ + Handler for Expect Header. + """ + expect = self.request.headers.get(EXPECT_HEADER) + if self.request.version == "1.1": + if expect.lower() == "100-continue": + self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") + else: + self.write_error(HeaderExpectationFailed("Unknow Expect: {expect}".format(expect=expect))) + def on_body(self, body): if self.is_request_stream and self._is_stream_handler: self._request_stream_task = self.loop.create_task( diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index d845dc85..823bd4e4 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -1,4 +1,6 @@ +import pytest from sanic.blueprints import Blueprint +from sanic.exceptions import HeaderExpectationFailed from sanic.request import StreamBuffer from sanic.response import stream, text from sanic.views import CompositionView, HTTPMethodView @@ -40,6 +42,53 @@ def test_request_stream_method_view(app): assert response.text == data +def test_request_stream_100_continue(app): + + class SimpleView(HTTPMethodView): + + @stream_decorator + async def post(self, request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + app.add_route(SimpleView.as_view(), "/method_view") + + assert app.is_request_stream is True + + request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"}) + assert response.status == 200 + assert response.text == data + + +def test_request_stream_100_continue_raise_HeaderExpectationFailed(app): + + class SimpleView(HTTPMethodView): + + @stream_decorator + async def post(self, request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + app.add_route(SimpleView.as_view(), "/method_view") + + assert app.is_request_stream is True + with pytest.raises(ValueError) as e: + app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) + assert "Unknow Expect: 100-continue-extra" in str(e) + + def test_request_stream_app(app): """for self.is_request_stream = True and decorators""" From 2631f10c5e3531d921a3446c0bc7b2fa43db3a17 Mon Sep 17 00:00:00 2001 From: Yun Xu Date: Mon, 3 Jun 2019 22:12:10 -0700 Subject: [PATCH 2/4] lint: fix isort and flake8 complains --- sanic/server.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sanic/server.py b/sanic/server.py index b4673cac..460436a6 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -23,7 +23,7 @@ from sanic.exceptions import ( ServiceUnavailable, ) from sanic.log import access_logger, logger -from sanic.request import Request, StreamBuffer, EXPECT_HEADER +from sanic.request import EXPECT_HEADER, Request, StreamBuffer from sanic.response import HTTPResponse @@ -338,7 +338,11 @@ class HttpProtocol(asyncio.Protocol): if expect.lower() == "100-continue": self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") else: - self.write_error(HeaderExpectationFailed("Unknow Expect: {expect}".format(expect=expect))) + self.write_error( + HeaderExpectationFailed( + "Unknow Expect: {expect}".format(expect=expect) + ) + ) def on_body(self, body): if self.is_request_stream and self._is_stream_handler: From 39d134994d7652b5d380c6ee35eb7264677fa4c2 Mon Sep 17 00:00:00 2001 From: Yun Xu Date: Tue, 4 Jun 2019 10:25:32 -0700 Subject: [PATCH 3/4] minor: address pr feedbacks, small refactoring and fix --- sanic/server.py | 2 +- tests/test_request_stream.py | 41 ++++++++++++------------------------ 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/sanic/server.py b/sanic/server.py index 460436a6..f8a9b203 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -340,7 +340,7 @@ class HttpProtocol(asyncio.Protocol): else: self.write_error( HeaderExpectationFailed( - "Unknow Expect: {expect}".format(expect=expect) + "Unknown Expect: {expect}".format(expect=expect) ) ) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 823bd4e4..7e406ac0 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -42,8 +42,11 @@ def test_request_stream_method_view(app): assert response.text == data -def test_request_stream_100_continue(app): - +@pytest.mark.parametrize("headers, expect_raise_exception", [ +({"EXPECT": "100-continue"}, False), +({"EXPECT": "100-continue-extra"}, True), +]) +def test_request_stream_100_continue(app, headers, expect_raise_exception): class SimpleView(HTTPMethodView): @stream_decorator @@ -61,34 +64,16 @@ def test_request_stream_100_continue(app): assert app.is_request_stream is True - request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"}) - assert response.status == 200 - assert response.text == data + if not expect_raise_exception: + request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"}) + assert response.status == 200 + assert response.text == data + else: + with pytest.raises(ValueError) as e: + app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) + assert "Unknow Expect: 100-continue-extra" in str(e) -def test_request_stream_100_continue_raise_HeaderExpectationFailed(app): - - class SimpleView(HTTPMethodView): - - @stream_decorator - async def post(self, request): - assert isinstance(request.stream, StreamBuffer) - result = "" - while True: - body = await request.stream.read() - if body is None: - break - result += body.decode("utf-8") - return text(result) - - app.add_route(SimpleView.as_view(), "/method_view") - - assert app.is_request_stream is True - with pytest.raises(ValueError) as e: - app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) - assert "Unknow Expect: 100-continue-extra" in str(e) - - def test_request_stream_app(app): """for self.is_request_stream = True and decorators""" From 1b1a51c1bbda2ad9d6de268ea6b4e82be4f63e25 Mon Sep 17 00:00:00 2001 From: Yun Xu Date: Tue, 4 Jun 2019 10:37:03 -0700 Subject: [PATCH 4/4] minor: fix typo in error msg --- tests/test_request_stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 7e406ac0..70fa621b 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -71,7 +71,7 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception): else: with pytest.raises(ValueError) as e: app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) - assert "Unknow Expect: 100-continue-extra" in str(e) + assert "Unknown Expect: 100-continue-extra" in str(e) def test_request_stream_app(app):