Fix #1587: add support for handling Expect Header (#1600)

Fix #1587: add support for handling Expect Header
This commit is contained in:
Eli Uriegas 2019-06-10 14:45:37 -07:00 committed by GitHub
commit 072fcfe03e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 2 deletions

View File

@ -218,6 +218,11 @@ class ContentRangeError(SanicException):
} }
@add_status_code(417)
class HeaderExpectationFailed(SanicException):
pass
@add_status_code(403) @add_status_code(403)
class Forbidden(SanicException): class Forbidden(SanicException):
pass pass

View File

@ -29,7 +29,7 @@ except ImportError:
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
EXPECT_HEADER = "EXPECT"
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 # HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
# > If the media type remains unknown, the recipient SHOULD treat it # > If the media type remains unknown, the recipient SHOULD treat it

View File

@ -15,6 +15,7 @@ from httptools.parser.errors import HttpParserError
from multidict import CIMultiDict from multidict import CIMultiDict
from sanic.exceptions import ( from sanic.exceptions import (
HeaderExpectationFailed,
InvalidUsage, InvalidUsage,
PayloadTooLarge, PayloadTooLarge,
RequestTimeout, RequestTimeout,
@ -22,7 +23,7 @@ from sanic.exceptions import (
ServiceUnavailable, ServiceUnavailable,
) )
from sanic.log import access_logger, logger from sanic.log import access_logger, logger
from sanic.request import Request, StreamBuffer from sanic.request import EXPECT_HEADER, Request, StreamBuffer
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
@ -314,6 +315,10 @@ class HttpProtocol(asyncio.Protocol):
if self._keep_alive_timeout_handler: if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel() self._keep_alive_timeout_handler.cancel()
self._keep_alive_timeout_handler = None self._keep_alive_timeout_handler = None
if self.request.headers.get(EXPECT_HEADER):
self.expect_handler()
if self.is_request_stream: if self.is_request_stream:
self._is_stream_handler = self.router.is_stream_handler( self._is_stream_handler = self.router.is_stream_handler(
self.request self.request
@ -324,6 +329,21 @@ class HttpProtocol(asyncio.Protocol):
) )
self.execute_request_handler() self.execute_request_handler()
def expect_handler(self):
"""
Handler for Expect Header.
"""
expect = self.request.headers.get(EXPECT_HEADER)
if self.request.version == "1.1":
if expect.lower() == "100-continue":
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else:
self.write_error(
HeaderExpectationFailed(
"Unknown Expect: {expect}".format(expect=expect)
)
)
def on_body(self, body): def on_body(self, body):
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task( self._request_stream_task = self.loop.create_task(

View File

@ -1,4 +1,6 @@
import pytest
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed
from sanic.request import StreamBuffer from sanic.request import StreamBuffer
from sanic.response import stream, text from sanic.response import stream, text
from sanic.views import CompositionView, HTTPMethodView from sanic.views import CompositionView, HTTPMethodView
@ -40,6 +42,38 @@ def test_request_stream_method_view(app):
assert response.text == data assert response.text == data
@pytest.mark.parametrize("headers, expect_raise_exception", [
({"EXPECT": "100-continue"}, False),
({"EXPECT": "100-continue-extra"}, True),
])
def test_request_stream_100_continue(app, headers, expect_raise_exception):
class SimpleView(HTTPMethodView):
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
app.add_route(SimpleView.as_view(), "/method_view")
assert app.is_request_stream is True
if not expect_raise_exception:
request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"})
assert response.status == 200
assert response.text == data
else:
with pytest.raises(ValueError) as e:
app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"})
assert "Unknown Expect: 100-continue-extra" in str(e)
def test_request_stream_app(app): def test_request_stream_app(app):
"""for self.is_request_stream = True and decorators""" """for self.is_request_stream = True and decorators"""