fix: handle expect header
This commit is contained in:
parent
c15158224b
commit
f21db60859
|
@ -218,6 +218,11 @@ class ContentRangeError(SanicException):
|
|||
}
|
||||
|
||||
|
||||
@add_status_code(417)
|
||||
class HeaderExpectationFailed(SanicException):
|
||||
pass
|
||||
|
||||
|
||||
@add_status_code(403)
|
||||
class Forbidden(SanicException):
|
||||
pass
|
||||
|
|
|
@ -29,7 +29,7 @@ except ImportError:
|
|||
|
||||
|
||||
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
|
||||
|
||||
EXPECT_HEADER = "EXPECT"
|
||||
|
||||
# HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
|
||||
# > If the media type remains unknown, the recipient SHOULD treat it
|
||||
|
|
|
@ -15,6 +15,7 @@ from httptools.parser.errors import HttpParserError
|
|||
from multidict import CIMultiDict
|
||||
|
||||
from sanic.exceptions import (
|
||||
HeaderExpectationFailed,
|
||||
InvalidUsage,
|
||||
PayloadTooLarge,
|
||||
RequestTimeout,
|
||||
|
@ -22,7 +23,7 @@ from sanic.exceptions import (
|
|||
ServiceUnavailable,
|
||||
)
|
||||
from sanic.log import access_logger, logger
|
||||
from sanic.request import Request, StreamBuffer
|
||||
from sanic.request import Request, StreamBuffer, EXPECT_HEADER
|
||||
from sanic.response import HTTPResponse
|
||||
|
||||
|
||||
|
@ -314,6 +315,10 @@ class HttpProtocol(asyncio.Protocol):
|
|||
if self._keep_alive_timeout_handler:
|
||||
self._keep_alive_timeout_handler.cancel()
|
||||
self._keep_alive_timeout_handler = None
|
||||
|
||||
if self.request.headers.get(EXPECT_HEADER):
|
||||
self.expect_handler()
|
||||
|
||||
if self.is_request_stream:
|
||||
self._is_stream_handler = self.router.is_stream_handler(
|
||||
self.request
|
||||
|
@ -324,6 +329,17 @@ class HttpProtocol(asyncio.Protocol):
|
|||
)
|
||||
self.execute_request_handler()
|
||||
|
||||
def expect_handler(self):
|
||||
"""
|
||||
Handler for Expect Header.
|
||||
"""
|
||||
expect = self.request.headers.get(EXPECT_HEADER)
|
||||
if self.request.version == "1.1":
|
||||
if expect.lower() == "100-continue":
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
else:
|
||||
self.write_error(HeaderExpectationFailed("Unknow Expect: {expect}".format(expect=expect)))
|
||||
|
||||
def on_body(self, body):
|
||||
if self.is_request_stream and self._is_stream_handler:
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import pytest
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.exceptions import HeaderExpectationFailed
|
||||
from sanic.request import StreamBuffer
|
||||
from sanic.response import stream, text
|
||||
from sanic.views import CompositionView, HTTPMethodView
|
||||
|
@ -40,6 +42,53 @@ def test_request_stream_method_view(app):
|
|||
assert response.text == data
|
||||
|
||||
|
||||
def test_request_stream_100_continue(app):
|
||||
|
||||
class SimpleView(HTTPMethodView):
|
||||
|
||||
@stream_decorator
|
||||
async def post(self, request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
app.add_route(SimpleView.as_view(), "/method_view")
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"})
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
||||
|
||||
|
||||
def test_request_stream_100_continue_raise_HeaderExpectationFailed(app):
|
||||
|
||||
class SimpleView(HTTPMethodView):
|
||||
|
||||
@stream_decorator
|
||||
async def post(self, request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
app.add_route(SimpleView.as_view(), "/method_view")
|
||||
|
||||
assert app.is_request_stream is True
|
||||
with pytest.raises(ValueError) as e:
|
||||
app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"})
|
||||
assert "Unknow Expect: 100-continue-extra" in str(e)
|
||||
|
||||
|
||||
def test_request_stream_app(app):
|
||||
"""for self.is_request_stream = True and decorators"""
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user